from ultralytics import YOLO import cv2 import os import numpy as np # 设置类别名称(必须与训练时一致) CLASS_NAMES = ['ban', 'bag'] # ✅ 确保顺序正确,对应模型的 class_id COLORS = [(0, 255, 0), (255, 0, 0)] # ban: 绿色, bag: 蓝色 def get_best_angles_per_class(image_path, weight_path, return_degree=False): """ 输入: image_path: 图像路径 weight_path: YOLO OBB 权重路径 return_degree: 是否返回角度(单位:度),否则为弧度 输出: 字典:{ class_name: best_angle 或 None } """ img = cv2.imread(image_path) if img is None: print(f"❌ 无法读取图像:{image_path}") return {cls: None for cls in CLASS_NAMES} model = YOLO(weight_path) results = model(img, save=False, imgsz=640, conf=0.15, task='obb') result = results[0] boxes = result.obb if boxes is None or len(boxes) == 0: print("⚠️ 未检测到任何目标。") return {cls: None for cls in CLASS_NAMES} # 提取数据 xywhr = boxes.xywhr.cpu().numpy() # (N, 5) -> cx, cy, w, h, r (弧度) confs = boxes.conf.cpu().numpy() # (N,) class_ids = boxes.cls.cpu().numpy().astype(int) # (N,) # 初始化结果字典 best_angles = {cls: None for cls in CLASS_NAMES} # 对每个类别找置信度最高的框 for class_id, class_name in enumerate(CLASS_NAMES): mask = (class_ids == class_id) if not np.any(mask): print(f"🟡 未检测到类别: {class_name}") continue # 找该类别中置信度最高的 idx_in_class = np.argmax(confs[mask]) global_idx = np.where(mask)[0][idx_in_class] angle_rad = xywhr[global_idx][4] best_angles[class_name] = np.degrees(angle_rad) if return_degree else angle_rad return best_angles def save_obb_visual(image_path, weight_path, save_path): """ 输入: image_path: 图像路径 weight_path: YOLO权重路径 save_path: 保存带标注图像路径 功能: 检测所有 OBB,绘制框、类别名、旋转角度,保存图片 """ img = cv2.imread(image_path) if img is None: print(f"❌ 无法读取图像:{image_path}") return model = YOLO(weight_path) results = model(img, save=False, imgsz=640, conf=0.05, task='obb') result = results[0] boxes = result.obb if boxes is None or len(boxes) == 0: print("⚠️ 未检测到任何目标。") # 仍保存原图 os.makedirs(os.path.dirname(save_path), exist_ok=True) cv2.imwrite(save_path, img) return # 提取信息 xywhr = boxes.xywhr.cpu().numpy() confs = boxes.conf.cpu().numpy() class_ids = boxes.cls.cpu().numpy().astype(int) # 绘制 annotated_img = img.copy() for i in range(len(boxes)): cx, cy, w, h, r = xywhr[i] angle_deg = np.degrees(r) class_id = class_ids[i] class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"cls{class_id}" conf = confs[i] color = COLORS[class_id % len(COLORS)] if class_id < len(CLASS_NAMES) else (128, 128, 128) # 绘制旋转框 rect = ((cx, cy), (w, h), angle_deg) box_pts = cv2.boxPoints(rect).astype(int) cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color, thickness=2) # 标注文本:类别 + 置信度 + 角度 text = f"{class_name} {conf:.2f} {angle_deg:.1f}°" font_scale = 0.7 thickness = 2 text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness) # 文本背景 cv2.rectangle(annotated_img, (box_pts[0][0], box_pts[0][1] - text_size[1] - 8), (box_pts[0][0] + text_size[0], box_pts[0][1] + 2), color, -1) # 文本 cv2.putText(annotated_img, text, (box_pts[0][0], box_pts[0][1] - 5), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness) # 保存 os.makedirs(os.path.dirname(save_path), exist_ok=True) cv2.imwrite(save_path, annotated_img) print(f"✅ 检测结果已保存至: {save_path}") # =============================== # 示例调用 # =============================== if __name__ == "__main__": weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt" image = r"/home/hx/yolo/output_masks/2.jpg" save_path = "./inference_results/visualized_2.jpg" # 获取每个类别的最佳角度(以度为单位) angles_deg = get_best_angles_per_class(image, weight, return_degree=True) print("\n🎯 各类别最佳旋转角度(度):") for cls_name, angle in angles_deg.items(): if angle is not None: print(f" {cls_name}: {angle:.2f}°") else: print(f" {cls_name}: 未检测到") # 可视化所有检测结果 save_obb_visual(image, weight, save_path)