106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
|
||
def get_best_obb_angle(image_path, weight_path, return_degree=False):
|
||
"""
|
||
输入:
|
||
image_path: 图像路径
|
||
weight_path: YOLO权重路径
|
||
return_degree: 是否返回角度单位为度,默认 False(返回弧度)
|
||
输出:
|
||
置信度最高目标的旋转角
|
||
如果未检测到目标返回 None
|
||
"""
|
||
# 读取图像
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图像:{image_path}")
|
||
return None
|
||
|
||
# 加载模型并预测
|
||
model = YOLO(weight_path)
|
||
results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
|
||
result = results[0]
|
||
|
||
boxes = result.obb
|
||
if not boxes:
|
||
print("⚠️ 未检测到目标。")
|
||
return None
|
||
|
||
# 取置信度最高框的旋转角
|
||
best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
|
||
r = best_box.xywhr.cpu().numpy()[0][4] # 弧度
|
||
|
||
if return_degree:
|
||
return np.degrees(r)
|
||
else:
|
||
return r
|
||
|
||
|
||
def save_obb_visual(image_path, weight_path, save_path):
|
||
"""
|
||
输入:
|
||
image_path: 图像路径
|
||
weight_path: YOLO权重路径
|
||
save_path: 保存带角度标注图像路径
|
||
功能:
|
||
检测 OBB 并标注置信度最高框旋转角度,保存图片
|
||
"""
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图像:{image_path}")
|
||
return
|
||
|
||
model = YOLO(weight_path)
|
||
results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
|
||
result = results[0]
|
||
|
||
boxes = result.obb
|
||
if not boxes:
|
||
print("⚠️ 未检测到目标。")
|
||
return
|
||
|
||
best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
|
||
cx, cy, w, h, r = best_box.xywhr.cpu().numpy()[0]
|
||
angle_deg = np.degrees(r)
|
||
|
||
# 绘制 OBB
|
||
annotated_img = img.copy()
|
||
rect = ((cx, cy), (w, h), angle_deg)
|
||
box_pts = cv2.boxPoints(rect).astype(int)
|
||
cv2.polylines(annotated_img, [box_pts], isClosed=True, color=(0, 255, 0), thickness=2)
|
||
|
||
# 标注角度
|
||
text = f"{angle_deg:.1f}°"
|
||
font_scale = max(0.5, min(w, h)/100)
|
||
thickness = 2
|
||
text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||
text_x = int(cx - text_size[0]/2)
|
||
text_y = int(cy + text_size[1]/2)
|
||
cv2.putText(annotated_img, text, (text_x, text_y),
|
||
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness)
|
||
|
||
# 保存
|
||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||
cv2.imwrite(save_path, annotated_img)
|
||
print(f"✅ 检测结果已保存至: {save_path}")
|
||
|
||
|
||
# ===============================
|
||
# 示例调用
|
||
# ===============================
|
||
if __name__ == "__main__":
|
||
weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
|
||
image = r"/home/hx/yolo/output_masks/2.jpg"
|
||
save_path = "./inference_results/best_detected_2.jpg"
|
||
|
||
angle_rad = get_best_obb_angle(image, weight)
|
||
print(f"旋转角(弧度):{angle_rad:.4f}")
|
||
|
||
angle_deg = get_best_obb_angle(image, weight, return_degree=True)
|
||
print(f"旋转角(度):{angle_deg:.2f}°")
|
||
|
||
save_obb_visual(image, weight, save_path)
|