Files
zjsh_yolov11/ailai_obb/angle.py
2025-09-15 15:35:19 +08:00

106 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import os
import numpy as np
def get_best_obb_angle(image_path, weight_path, return_degree=False):
    """Return the rotation angle of the highest-confidence oriented box.

    Args:
        image_path: Path of the image to run detection on.
        weight_path: Path of the YOLO weights file.
        return_degree: When true, the angle is converted with ``np.degrees``
            before being returned; otherwise the raw fifth component of the
            box's ``xywhr`` is returned (radians, per the original author's
            note).

    Returns:
        The rotation angle of the most confident detection, or ``None`` when
        the image cannot be read or nothing is detected.
    """
    # Load the image first; bail out before touching the model if it fails.
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ 无法读取图像:{image_path}")
        return None

    # Load the model and run inference on the in-memory image.
    detector = YOLO(weight_path)
    predictions = detector(image, save=False, imgsz=640, conf=0.15, mode='obb')
    detections = predictions[0].obb
    if not detections:
        print("⚠️ 未检测到目标。")
        return None

    def _confidence(candidate):
        # Each candidate carries a one-element confidence tensor.
        return candidate.conf.cpu().numpy()[0]

    # Pick the most confident box and read the last field of its xywhr row.
    top_box = max(detections, key=_confidence)
    angle = top_box.xywhr.cpu().numpy()[0][4]
    return np.degrees(angle) if return_degree else angle
def save_obb_visual(image_path, weight_path, save_path):
    """Detect oriented boxes and save the image annotated with the best one.

    Only the highest-confidence box is drawn (green outline), with its
    rotation angle in degrees written in red around the box centre.

    Args:
        image_path: Path of the image to run detection on.
        weight_path: Path of the YOLO weights file.
        save_path: Destination path of the annotated image. Missing parent
            directories are created; a bare file name (no directory part)
            is written relative to the current working directory.

    Returns:
        None. Success and failure are reported with ``print``.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return

    model = YOLO(weight_path)
    results = model(img, save=False, imgsz=640, conf=0.15, mode='obb')
    boxes = results[0].obb
    if not boxes:
        print("⚠️ 未检测到目标。")
        return

    best_box = max(boxes, key=lambda x: x.conf.cpu().numpy()[0])
    # Convert numpy scalars to plain floats before handing them to OpenCV.
    cx, cy, w, h, r = (float(v) for v in best_box.xywhr.cpu().numpy()[0])
    angle_deg = float(np.degrees(r))

    # Draw the oriented box.
    annotated_img = img.copy()
    rect = ((cx, cy), (w, h), angle_deg)
    # cv2.polylines works on 32-bit integer points; be explicit about it.
    box_pts = cv2.boxPoints(rect).astype(np.int32)
    cv2.polylines(annotated_img, [box_pts], isClosed=True, color=(0, 255, 0), thickness=2)

    # Label the angle. Hershey fonts only cover ASCII, so a "°" glyph would
    # be rendered as "??"; use an ASCII suffix instead.
    text = f"{angle_deg:.1f} deg"
    font_scale = max(0.5, min(w, h) / 100)
    thickness = 2
    text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    # Centre the text horizontally and vertically on the box centre.
    text_x = int(cx - text_size[0] / 2)
    text_y = int(cy + text_size[1] / 2)
    cv2.putText(annotated_img, text, (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness)

    # Save. os.makedirs("") raises FileNotFoundError, so only create the
    # parent directory when save_path actually has one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, annotated_img):
        print(f"❌ 保存失败: {save_path}")
        return
    print(f"✅ 检测结果已保存至: {save_path}")
# ===============================
# Example usage
# ===============================
if __name__ == "__main__":
    weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
    image = r"/home/hx/yolo/output_masks/2.jpg"
    save_path = "./inference_results/best_detected_2.jpg"

    # get_best_obb_angle returns None when the image is unreadable or nothing
    # is detected; formatting None with a float spec would raise TypeError,
    # so only print when an angle was actually produced.
    angle_rad = get_best_obb_angle(image, weight)
    if angle_rad is not None:
        print(f"旋转角(弧度):{angle_rad:.4f}")

    angle_deg = get_best_obb_angle(image, weight, return_degree=True)
    if angle_deg is not None:
        print(f"旋转角(度):{angle_deg:.2f}°")

    save_obb_visual(image, weight, save_path)