Files

145 lines
5.0 KiB
Python
Raw Permalink Normal View History

2025-10-21 11:07:29 +08:00
from ultralytics import YOLO
import cv2
import os
import numpy as np
# Class names, indexed by the model's class_id. The order must match the
# order used when the model was trained.
CLASS_NAMES = [
    'ban',  # class_id 0
    'bag',  # class_id 1
]
# Drawing colors indexed by class_id (OpenCV BGR order):
# ban -> green, bag -> blue.
COLORS = [
    (0, 255, 0),
    (255, 0, 0),
]
def get_best_angles_per_class(image_path, weight_path, return_degree=False, *,
                              conf=0.15, imgsz=640):
    """Return the rotation angle of the most confident OBB for each class.

    Runs a YOLO oriented-bounding-box model on one image. For every class in
    ``CLASS_NAMES``, it keeps the rotation angle of that class's detection
    with the highest confidence.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        weight_path: Path of the YOLO OBB weights file.
        return_degree: If True, angles are returned in degrees. Otherwise
            they are in radians, the unit of the ``r`` column of
            ``obb.xywhr``.
        conf: Confidence threshold passed to the model (default 0.15).
        imgsz: Inference image size passed to the model (default 640).

    Returns:
        dict mapping every name in ``CLASS_NAMES`` to a Python ``float``
        angle. A value is ``None`` when the image cannot be read, when
        nothing is detected, or when that class has no detection.
    """
    # Every class starts as "not detected"; early exits return this dict.
    best_angles = {cls: None for cls in CLASS_NAMES}

    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return best_angles

    model = YOLO(weight_path)
    result = model(img, save=False, imgsz=imgsz, conf=conf, task='obb')[0]
    boxes = result.obb
    if boxes is None or len(boxes) == 0:
        print("⚠️ 未检测到任何目标。")
        return best_angles

    # Move detection tensors to host memory as numpy arrays.
    xywhr = boxes.xywhr.cpu().numpy()  # (N, 5): cx, cy, w, h, r (radians)
    confs = boxes.conf.cpu().numpy()  # (N,)
    class_ids = boxes.cls.cpu().numpy().astype(int)  # (N,)

    for class_id, class_name in enumerate(CLASS_NAMES):
        # Indices (into all detections) of the boxes belonging to this class.
        candidates = np.where(class_ids == class_id)[0]
        if candidates.size == 0:
            print(f"🟡 未检测到类别: {class_name}")
            continue
        best = candidates[np.argmax(confs[candidates])]
        angle = np.degrees(xywhr[best, 4]) if return_degree else xywhr[best, 4]
        # Convert the numpy scalar to a plain Python float so callers can,
        # e.g., JSON-serialize the result (np.float32 is not serializable).
        best_angles[class_name] = float(angle)

    return best_angles
def _write_image_checked(save_path, image):
    """Write *image* to *save_path*, creating parent directories if needed.

    ``os.makedirs`` is only called when *save_path* has a directory
    component, because ``os.makedirs("")`` raises ``FileNotFoundError``.
    A False return from ``cv2.imwrite`` (e.g. an unwritable location) is
    reported instead of being silently ignored.

    Returns:
        True if the image was written, False otherwise.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not cv2.imwrite(save_path, image):
        print(f"❌ 无法保存图像:{save_path}")
        return False
    return True


def save_obb_visual(image_path, weight_path, save_path, *, conf=0.05, imgsz=640):
    """Detect all OBBs in an image and save an annotated copy.

    Each detection is drawn as a rotated polygon. A filled label shows the
    class name, the confidence and the rotation angle in degrees.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        weight_path: Path of the YOLO OBB weights file.
        save_path: Output image path; missing parent directories are created.
        conf: Confidence threshold passed to the model (default 0.05).
        imgsz: Inference image size passed to the model (default 640).

    Returns:
        None. An unreadable input or a failed write is reported with
        ``print``. When nothing is detected, the unannotated image is saved.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return

    model = YOLO(weight_path)
    result = model(img, save=False, imgsz=imgsz, conf=conf, task='obb')[0]
    boxes = result.obb
    if boxes is None or len(boxes) == 0:
        print("⚠️ 未检测到任何目标。")
        # Still save the original image so every input gets an output file.
        _write_image_checked(save_path, img)
        return

    xywhr = boxes.xywhr.cpu().numpy()  # (N, 5): cx, cy, w, h, r (radians)
    confs = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)

    annotated_img = img.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2
    for (cx, cy, w, h, r), score, class_id in zip(xywhr, confs, class_ids):
        angle_deg = float(np.degrees(r))
        known = class_id < len(CLASS_NAMES)
        class_name = CLASS_NAMES[class_id] if known else f"cls{class_id}"
        # Classes outside CLASS_NAMES are drawn in gray.
        color = COLORS[class_id % len(COLORS)] if known else (128, 128, 128)

        # cv2.boxPoints takes the angle in degrees. Round the corners to the
        # nearest pixel and use int32 points for cv2.polylines.
        rect = ((float(cx), float(cy)), (float(w), float(h)), angle_deg)
        box_pts = np.round(cv2.boxPoints(rect)).astype(np.int32)
        cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color,
                      thickness=thickness)

        # Hershey fonts only render ASCII, so a "°" sign would be drawn as
        # "?". Spell the unit out instead.
        text = f"{class_name} {score:.2f} {angle_deg:.1f}deg"
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # Anchor the label at the first box corner. Push it down when that
        # corner is too close to the top edge, so the label stays in the image.
        x0 = int(box_pts[0][0])
        y0 = max(int(box_pts[0][1]), text_h + 8)
        cv2.rectangle(annotated_img, (x0, y0 - text_h - 8),
                      (x0 + text_w, y0 + 2), color, -1)
        cv2.putText(annotated_img, text, (x0, y0 - 5), font, font_scale,
                    (255, 255, 255), thickness)

    if _write_image_checked(save_path, annotated_img):
        print(f"✅ 检测结果已保存至: {save_path}")
# ===============================
# Example usage
# ===============================
if __name__ == "__main__":
    import argparse

    # Paths can be overridden on the command line. The defaults are the
    # original hard-coded paths, so a bare run behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Print the best OBB angle per class and save a visualization.")
    parser.add_argument(
        "--weight",
        default=r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt",
        help="YOLO OBB weights file")
    parser.add_argument(
        "--image",
        default=r"/home/hx/yolo/output_masks/2.jpg",
        help="input image")
    parser.add_argument(
        "--save-path",
        default="./inference_results/visualized_2.jpg",
        help="where to write the annotated image")
    args = parser.parse_args()

    # Best rotation angle per class, in degrees.
    angles_deg = get_best_angles_per_class(args.image, args.weight, return_degree=True)
    print("\n🎯 各类别最佳旋转角度(度):")
    for cls_name, angle in angles_deg.items():
        if angle is not None:
            print(f" {cls_name}: {angle:.2f}°")
        else:
            print(f" {cls_name}: 未检测到")

    # Visualize all detections.
    save_obb_visual(args.image, args.weight, args.save_path)