145 lines
5.0 KiB
Python
145 lines
5.0 KiB
Python
|
|
from ultralytics import YOLO
|
|||
|
|
import cv2
|
|||
|
|
import os
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
# 设置类别名称(必须与训练时一致)
|
|||
|
|
CLASS_NAMES = ['ban', 'bag'] # ✅ 确保顺序正确,对应模型的 class_id
|
|||
|
|
COLORS = [(0, 255, 0), (255, 0, 0)] # ban: 绿色, bag: 蓝色
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_best_angles_per_class(image_path, weight_path, return_degree=False):
|
|||
|
|
"""
|
|||
|
|
输入:
|
|||
|
|
image_path: 图像路径
|
|||
|
|
weight_path: YOLO OBB 权重路径
|
|||
|
|
return_degree: 是否返回角度(单位:度),否则为弧度
|
|||
|
|
输出:
|
|||
|
|
字典:{ class_name: best_angle 或 None }
|
|||
|
|
"""
|
|||
|
|
img = cv2.imread(image_path)
|
|||
|
|
if img is None:
|
|||
|
|
print(f"❌ 无法读取图像:{image_path}")
|
|||
|
|
return {cls: None for cls in CLASS_NAMES}
|
|||
|
|
|
|||
|
|
model = YOLO(weight_path)
|
|||
|
|
results = model(img, save=False, imgsz=640, conf=0.15, task='obb')
|
|||
|
|
result = results[0]
|
|||
|
|
|
|||
|
|
boxes = result.obb
|
|||
|
|
if boxes is None or len(boxes) == 0:
|
|||
|
|
print("⚠️ 未检测到任何目标。")
|
|||
|
|
return {cls: None for cls in CLASS_NAMES}
|
|||
|
|
|
|||
|
|
# 提取数据
|
|||
|
|
xywhr = boxes.xywhr.cpu().numpy() # (N, 5) -> cx, cy, w, h, r (弧度)
|
|||
|
|
confs = boxes.conf.cpu().numpy() # (N,)
|
|||
|
|
class_ids = boxes.cls.cpu().numpy().astype(int) # (N,)
|
|||
|
|
|
|||
|
|
# 初始化结果字典
|
|||
|
|
best_angles = {cls: None for cls in CLASS_NAMES}
|
|||
|
|
|
|||
|
|
# 对每个类别找置信度最高的框
|
|||
|
|
for class_id, class_name in enumerate(CLASS_NAMES):
|
|||
|
|
mask = (class_ids == class_id)
|
|||
|
|
if not np.any(mask):
|
|||
|
|
print(f"🟡 未检测到类别: {class_name}")
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 找该类别中置信度最高的
|
|||
|
|
idx_in_class = np.argmax(confs[mask])
|
|||
|
|
global_idx = np.where(mask)[0][idx_in_class]
|
|||
|
|
angle_rad = xywhr[global_idx][4]
|
|||
|
|
|
|||
|
|
best_angles[class_name] = np.degrees(angle_rad) if return_degree else angle_rad
|
|||
|
|
|
|||
|
|
return best_angles
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_obb_visual(image_path, weight_path, save_path):
|
|||
|
|
"""
|
|||
|
|
输入:
|
|||
|
|
image_path: 图像路径
|
|||
|
|
weight_path: YOLO权重路径
|
|||
|
|
save_path: 保存带标注图像路径
|
|||
|
|
功能:
|
|||
|
|
检测所有 OBB,绘制框、类别名、旋转角度,保存图片
|
|||
|
|
"""
|
|||
|
|
img = cv2.imread(image_path)
|
|||
|
|
if img is None:
|
|||
|
|
print(f"❌ 无法读取图像:{image_path}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
model = YOLO(weight_path)
|
|||
|
|
results = model(img, save=False, imgsz=640, conf=0.15, task='obb')
|
|||
|
|
result = results[0]
|
|||
|
|
|
|||
|
|
boxes = result.obb
|
|||
|
|
if boxes is None or len(boxes) == 0:
|
|||
|
|
print("⚠️ 未检测到任何目标。")
|
|||
|
|
# 仍保存原图
|
|||
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|||
|
|
cv2.imwrite(save_path, img)
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 提取信息
|
|||
|
|
xywhr = boxes.xywhr.cpu().numpy()
|
|||
|
|
confs = boxes.conf.cpu().numpy()
|
|||
|
|
class_ids = boxes.cls.cpu().numpy().astype(int)
|
|||
|
|
|
|||
|
|
# 绘制
|
|||
|
|
annotated_img = img.copy()
|
|||
|
|
for i in range(len(boxes)):
|
|||
|
|
cx, cy, w, h, r = xywhr[i]
|
|||
|
|
angle_deg = np.degrees(r)
|
|||
|
|
class_id = class_ids[i]
|
|||
|
|
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"cls{class_id}"
|
|||
|
|
conf = confs[i]
|
|||
|
|
color = COLORS[class_id % len(COLORS)] if class_id < len(CLASS_NAMES) else (128, 128, 128)
|
|||
|
|
|
|||
|
|
# 绘制旋转框
|
|||
|
|
rect = ((cx, cy), (w, h), angle_deg)
|
|||
|
|
box_pts = cv2.boxPoints(rect).astype(int)
|
|||
|
|
cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color, thickness=2)
|
|||
|
|
|
|||
|
|
# 标注文本:类别 + 置信度 + 角度
|
|||
|
|
text = f"{class_name} {conf:.2f} {angle_deg:.1f}°"
|
|||
|
|
font_scale = 0.7
|
|||
|
|
thickness = 2
|
|||
|
|
text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
|||
|
|
|
|||
|
|
# 文本背景
|
|||
|
|
cv2.rectangle(annotated_img,
|
|||
|
|
(box_pts[0][0], box_pts[0][1] - text_size[1] - 8),
|
|||
|
|
(box_pts[0][0] + text_size[0], box_pts[0][1] + 2),
|
|||
|
|
color, -1)
|
|||
|
|
# 文本
|
|||
|
|
cv2.putText(annotated_img, text,
|
|||
|
|
(box_pts[0][0], box_pts[0][1] - 5),
|
|||
|
|
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
|
|||
|
|
|
|||
|
|
# 保存
|
|||
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|||
|
|
cv2.imwrite(save_path, annotated_img)
|
|||
|
|
print(f"✅ 检测结果已保存至: {save_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===============================
|
|||
|
|
# 示例调用
|
|||
|
|
# ===============================
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb_ailai/weights/best.pt"
|
|||
|
|
image = r"/home/hx/yolo/ailai_obb/camera01/1.jpg"
|
|||
|
|
save_path = "./inference_results/visualized_2.jpg"
|
|||
|
|
|
|||
|
|
# 获取每个类别的最佳角度(以度为单位)
|
|||
|
|
angles_deg = get_best_angles_per_class(image, weight, return_degree=True)
|
|||
|
|
print("\n🎯 各类别最佳旋转角度(度):")
|
|||
|
|
for cls_name, angle in angles_deg.items():
|
|||
|
|
if angle is not None:
|
|||
|
|
print(f" {cls_name}: {angle:.2f}°")
|
|||
|
|
else:
|
|||
|
|
print(f" {cls_name}: 未检测到")
|
|||
|
|
|
|||
|
|
# 可视化所有检测结果
|
|||
|
|
save_obb_visual(image, weight, save_path)
|