106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
|
|
from ultralytics import YOLO
|
|||
|
|
import cv2
|
|||
|
|
import os
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
def get_best_obb_angle(image_path, weight_path, return_degree=False):
|
|||
|
|
"""
|
|||
|
|
输入:
|
|||
|
|
image_path: 图像路径
|
|||
|
|
weight_path: YOLO权重路径
|
|||
|
|
return_degree: 是否返回角度单位为度,默认 False(返回弧度)
|
|||
|
|
输出:
|
|||
|
|
置信度最高目标的旋转角
|
|||
|
|
如果未检测到目标返回 None
|
|||
|
|
"""
|
|||
|
|
# 读取图像
|
|||
|
|
img = cv2.imread(image_path)
|
|||
|
|
if img is None:
|
|||
|
|
print(f"❌ 无法读取图像:{image_path}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 加载模型并预测
|
|||
|
|
model = YOLO(weight_path)
|
|||
|
|
results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
|
|||
|
|
result = results[0]
|
|||
|
|
|
|||
|
|
boxes = result.obb
|
|||
|
|
if not boxes:
|
|||
|
|
print("⚠️ 未检测到目标。")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 取置信度最高框的旋转角
|
|||
|
|
best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
|
|||
|
|
r = best_box.xywhr.cpu().numpy()[0][4] # 弧度
|
|||
|
|
|
|||
|
|
if return_degree:
|
|||
|
|
return np.degrees(r)
|
|||
|
|
else:
|
|||
|
|
return r
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_obb_visual(image_path, weight_path, save_path):
|
|||
|
|
"""
|
|||
|
|
输入:
|
|||
|
|
image_path: 图像路径
|
|||
|
|
weight_path: YOLO权重路径
|
|||
|
|
save_path: 保存带角度标注图像路径
|
|||
|
|
功能:
|
|||
|
|
检测 OBB 并标注置信度最高框旋转角度,保存图片
|
|||
|
|
"""
|
|||
|
|
img = cv2.imread(image_path)
|
|||
|
|
if img is None:
|
|||
|
|
print(f"❌ 无法读取图像:{image_path}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
model = YOLO(weight_path)
|
|||
|
|
results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
|
|||
|
|
result = results[0]
|
|||
|
|
|
|||
|
|
boxes = result.obb
|
|||
|
|
if not boxes:
|
|||
|
|
print("⚠️ 未检测到目标。")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
|
|||
|
|
cx, cy, w, h, r = best_box.xywhr.cpu().numpy()[0]
|
|||
|
|
angle_deg = np.degrees(r)
|
|||
|
|
|
|||
|
|
# 绘制 OBB
|
|||
|
|
annotated_img = img.copy()
|
|||
|
|
rect = ((cx, cy), (w, h), angle_deg)
|
|||
|
|
box_pts = cv2.boxPoints(rect).astype(int)
|
|||
|
|
cv2.polylines(annotated_img, [box_pts], isClosed=True, color=(0, 255, 0), thickness=2)
|
|||
|
|
|
|||
|
|
# 标注角度
|
|||
|
|
text = f"{angle_deg:.1f}°"
|
|||
|
|
font_scale = max(0.5, min(w, h)/100)
|
|||
|
|
thickness = 2
|
|||
|
|
text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
|||
|
|
text_x = int(cx - text_size[0]/2)
|
|||
|
|
text_y = int(cy + text_size[1]/2)
|
|||
|
|
cv2.putText(annotated_img, text, (text_x, text_y),
|
|||
|
|
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness)
|
|||
|
|
|
|||
|
|
# 保存
|
|||
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|||
|
|
cv2.imwrite(save_path, annotated_img)
|
|||
|
|
print(f"✅ 检测结果已保存至: {save_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===============================
|
|||
|
|
# 示例调用
|
|||
|
|
# ===============================
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
|
|||
|
|
image = r"/home/hx/yolo/output_masks/2.jpg"
|
|||
|
|
save_path = "./inference_results/best_detected_2.jpg"
|
|||
|
|
|
|||
|
|
angle_rad = get_best_obb_angle(image, weight)
|
|||
|
|
print(f"旋转角(弧度):{angle_rad:.4f}")
|
|||
|
|
|
|||
|
|
angle_deg = get_best_obb_angle(image, weight, return_degree=True)
|
|||
|
|
print(f"旋转角(度):{angle_deg:.2f}°")
|
|||
|
|
|
|||
|
|
save_obb_visual(image, weight, save_path)
|