145 lines
5.0 KiB
Python
145 lines
5.0 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
|
||
# 设置类别名称(必须与训练时一致)
|
||
CLASS_NAMES = ['ban', 'bag'] # ✅ 确保顺序正确,对应模型的 class_id
|
||
COLORS = [(0, 255, 0), (255, 0, 0)] # ban: 绿色, bag: 蓝色
|
||
|
||
|
||
def get_best_angles_per_class(image_path, weight_path, return_degree=False):
|
||
"""
|
||
输入:
|
||
image_path: 图像路径
|
||
weight_path: YOLO OBB 权重路径
|
||
return_degree: 是否返回角度(单位:度),否则为弧度
|
||
输出:
|
||
字典:{ class_name: best_angle 或 None }
|
||
"""
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图像:{image_path}")
|
||
return {cls: None for cls in CLASS_NAMES}
|
||
|
||
model = YOLO(weight_path)
|
||
results = model(img, save=False, imgsz=640, conf=0.15, task='obb')
|
||
result = results[0]
|
||
|
||
boxes = result.obb
|
||
if boxes is None or len(boxes) == 0:
|
||
print("⚠️ 未检测到任何目标。")
|
||
return {cls: None for cls in CLASS_NAMES}
|
||
|
||
# 提取数据
|
||
xywhr = boxes.xywhr.cpu().numpy() # (N, 5) -> cx, cy, w, h, r (弧度)
|
||
confs = boxes.conf.cpu().numpy() # (N,)
|
||
class_ids = boxes.cls.cpu().numpy().astype(int) # (N,)
|
||
|
||
# 初始化结果字典
|
||
best_angles = {cls: None for cls in CLASS_NAMES}
|
||
|
||
# 对每个类别找置信度最高的框
|
||
for class_id, class_name in enumerate(CLASS_NAMES):
|
||
mask = (class_ids == class_id)
|
||
if not np.any(mask):
|
||
print(f"🟡 未检测到类别: {class_name}")
|
||
continue
|
||
|
||
# 找该类别中置信度最高的
|
||
idx_in_class = np.argmax(confs[mask])
|
||
global_idx = np.where(mask)[0][idx_in_class]
|
||
angle_rad = xywhr[global_idx][4]
|
||
|
||
best_angles[class_name] = np.degrees(angle_rad) if return_degree else angle_rad
|
||
|
||
return best_angles
|
||
|
||
|
||
def save_obb_visual(image_path, weight_path, save_path):
|
||
"""
|
||
输入:
|
||
image_path: 图像路径
|
||
weight_path: YOLO权重路径
|
||
save_path: 保存带标注图像路径
|
||
功能:
|
||
检测所有 OBB,绘制框、类别名、旋转角度,保存图片
|
||
"""
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
print(f"❌ 无法读取图像:{image_path}")
|
||
return
|
||
|
||
model = YOLO(weight_path)
|
||
results = model(img, save=False, imgsz=640, conf=0.05, task='obb')
|
||
result = results[0]
|
||
|
||
boxes = result.obb
|
||
if boxes is None or len(boxes) == 0:
|
||
print("⚠️ 未检测到任何目标。")
|
||
# 仍保存原图
|
||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||
cv2.imwrite(save_path, img)
|
||
return
|
||
|
||
# 提取信息
|
||
xywhr = boxes.xywhr.cpu().numpy()
|
||
confs = boxes.conf.cpu().numpy()
|
||
class_ids = boxes.cls.cpu().numpy().astype(int)
|
||
|
||
# 绘制
|
||
annotated_img = img.copy()
|
||
for i in range(len(boxes)):
|
||
cx, cy, w, h, r = xywhr[i]
|
||
angle_deg = np.degrees(r)
|
||
class_id = class_ids[i]
|
||
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"cls{class_id}"
|
||
conf = confs[i]
|
||
color = COLORS[class_id % len(COLORS)] if class_id < len(CLASS_NAMES) else (128, 128, 128)
|
||
|
||
# 绘制旋转框
|
||
rect = ((cx, cy), (w, h), angle_deg)
|
||
box_pts = cv2.boxPoints(rect).astype(int)
|
||
cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color, thickness=2)
|
||
|
||
# 标注文本:类别 + 置信度 + 角度
|
||
text = f"{class_name} {conf:.2f} {angle_deg:.1f}°"
|
||
font_scale = 0.7
|
||
thickness = 2
|
||
text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||
|
||
# 文本背景
|
||
cv2.rectangle(annotated_img,
|
||
(box_pts[0][0], box_pts[0][1] - text_size[1] - 8),
|
||
(box_pts[0][0] + text_size[0], box_pts[0][1] + 2),
|
||
color, -1)
|
||
# 文本
|
||
cv2.putText(annotated_img, text,
|
||
(box_pts[0][0], box_pts[0][1] - 5),
|
||
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
|
||
|
||
# 保存
|
||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||
cv2.imwrite(save_path, annotated_img)
|
||
print(f"✅ 检测结果已保存至: {save_path}")
|
||
|
||
|
||
# ===============================
|
||
# 示例调用
|
||
# ===============================
|
||
if __name__ == "__main__":
|
||
weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
|
||
image = r"/home/hx/yolo/output_masks/2.jpg"
|
||
save_path = "./inference_results/visualized_2.jpg"
|
||
|
||
# 获取每个类别的最佳角度(以度为单位)
|
||
angles_deg = get_best_angles_per_class(image, weight, return_degree=True)
|
||
print("\n🎯 各类别最佳旋转角度(度):")
|
||
for cls_name, angle in angles_deg.items():
|
||
if angle is not None:
|
||
print(f" {cls_name}: {angle:.2f}°")
|
||
else:
|
||
print(f" {cls_name}: 未检测到")
|
||
|
||
# 可视化所有检测结果
|
||
save_obb_visual(image, weight, save_path) |