Files
ailai_image_point_diff/ailai_pc/angle.py
琉璃月光 c134abf749 first commit
2025-10-21 11:07:29 +08:00

145 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import os
import numpy as np
# Class names indexed by the model's class_id.
# The order must match the order used when the model was trained.
CLASS_NAMES = [
    "ban",  # class_id 0
    "bag",  # class_id 1
]
# Drawing color per class_id, as OpenCV-style BGR tuples.
COLORS = [
    (0, 255, 0),  # ban: green
    (255, 0, 0),  # bag: blue (B channel in BGR order)
]
def get_best_angles_per_class(image_path, weight_path, return_degree=False,
                              *, conf=0.15, imgsz=640):
    """Return the rotation angle of the most confident OBB for each class.

    Runs the YOLO OBB model loaded from ``weight_path`` on the image at
    ``image_path``. For every name in ``CLASS_NAMES``, it picks the detection
    of that class with the highest confidence and reports that box's angle.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        weight_path: Path to the YOLO OBB weights file.
        return_degree: If True, return angles in degrees. Otherwise return
            them in radians, the unit of the last column of ``obb.xywhr``.
        conf: Confidence threshold passed to the model. The default 0.15 is
            the value this function previously hard-coded.
        imgsz: Inference image size passed to the model (default 640).

    Returns:
        A dict mapping every name in ``CLASS_NAMES`` to a Python ``float``
        angle. A name maps to ``None`` when that class was not detected, or
        when the image could not be read or nothing was detected at all.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return {cls: None for cls in CLASS_NAMES}

    model = YOLO(weight_path)
    results = model(img, save=False, imgsz=imgsz, conf=conf, task='obb')
    boxes = results[0].obb
    if boxes is None or len(boxes) == 0:
        print("⚠️ 未检测到任何目标。")
        return {cls: None for cls in CLASS_NAMES}

    # Move detections to NumPy: xywhr is (N, 5) = cx, cy, w, h, r (radians).
    xywhr = boxes.xywhr.cpu().numpy()
    confs = boxes.conf.cpu().numpy()                 # (N,)
    class_ids = boxes.cls.cpu().numpy().astype(int)  # (N,)

    best_angles = {cls: None for cls in CLASS_NAMES}
    for class_id, class_name in enumerate(CLASS_NAMES):
        candidates = np.flatnonzero(class_ids == class_id)
        if candidates.size == 0:
            print(f"🟡 未检测到类别: {class_name}")
            continue
        # Global index of the highest-confidence detection of this class.
        best = candidates[np.argmax(confs[candidates])]
        # Convert to a plain float so the result is JSON-serializable and
        # free of float32 rounding surprises in downstream arithmetic.
        angle_rad = float(xywhr[best, 4])
        best_angles[class_name] = (
            float(np.degrees(angle_rad)) if return_degree else angle_rad
        )
    return best_angles
def save_obb_visual(image_path, weight_path, save_path, *, conf=0.05, imgsz=640):
    """Detect all OBBs in an image, draw them, and save the annotated image.

    Each detection is drawn as a rotated polygon. Its label shows the class
    name, the confidence, and the rotation angle in degrees. When nothing is
    detected, the unannotated input image is saved instead.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        weight_path: Path to the YOLO OBB weights file.
        save_path: Output path for the annotated image. Missing parent
            directories are created. A bare filename writes to the current
            working directory.
        conf: Confidence threshold passed to the model. The default 0.05 is
            the value this function previously hard-coded.
        imgsz: Inference image size passed to the model (default 640).
    """

    def _write_image(image):
        # Write `image` to save_path and return whether it succeeded.
        # os.path.dirname("out.jpg") is "", and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # cv2.imwrite signals some failures by returning False instead of
        # raising.
        return cv2.imwrite(save_path, image)

    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像:{image_path}")
        return

    model = YOLO(weight_path)
    results = model(img, save=False, imgsz=imgsz, conf=conf, task='obb')
    boxes = results[0].obb
    if boxes is None or len(boxes) == 0:
        print("⚠️ 未检测到任何目标。")
        # Still save the original image so the output file always exists.
        if not _write_image(img):
            print(f"❌ 无法保存图像:{save_path}")
        return

    # xywhr is (N, 5) = cx, cy, w, h, r with r in radians.
    xywhr = boxes.xywhr.cpu().numpy()
    confs = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)

    annotated_img = img.copy()
    img_h, img_w = annotated_img.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2

    for (cx, cy, w, h, r), class_id, score in zip(xywhr, class_ids, confs):
        angle_deg = float(np.degrees(r))
        if class_id < len(CLASS_NAMES):
            class_name = CLASS_NAMES[class_id]
            color = COLORS[class_id % len(COLORS)]
        else:
            # Ids outside CLASS_NAMES get a generic name and a gray color.
            class_name = f"cls{class_id}"
            color = (128, 128, 128)

        # Rotated box: cv2.RotatedRect takes its angle in degrees.
        rect = ((float(cx), float(cy)), (float(w), float(h)), angle_deg)
        # Round rather than truncate. Polyline points must be int32.
        box_pts = np.round(cv2.boxPoints(rect)).astype(np.int32)
        cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color, thickness=thickness)

        # Label text: class, confidence, angle.
        text = f"{class_name} {score:.2f} {angle_deg:.1f}°"
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)

        # Anchor the label at the first box corner, clamped so the label and
        # its background stay inside the image.
        x0 = max(0, min(int(box_pts[0][0]), img_w - text_w))
        y0 = max(int(box_pts[0][1]), text_h + 8)

        # Filled background, then white text on top.
        cv2.rectangle(annotated_img,
                      (x0, y0 - text_h - 8),
                      (x0 + text_w, y0 + 2),
                      color, -1)
        cv2.putText(annotated_img, text, (x0, y0 - 5),
                    font, font_scale, (255, 255, 255), thickness)

    if _write_image(annotated_img):
        print(f"✅ 检测结果已保存至: {save_path}")
    else:
        print(f"❌ 无法保存图像:{save_path}")
# ===============================
# Example usage
# ===============================
if __name__ == "__main__":
    import argparse

    # Defaults are the original example paths, so running the script with
    # no arguments behaves exactly as before. Each can be overridden.
    parser = argparse.ArgumentParser(
        description="Print per-class best OBB angles and save a visualization.")
    parser.add_argument(
        "--weight",
        default=r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt",
        help="path to the YOLO OBB weights")
    parser.add_argument(
        "--image",
        default=r"/home/hx/yolo/output_masks/2.jpg",
        help="path to the input image")
    parser.add_argument(
        "--save-path",
        default="./inference_results/visualized_2.jpg",
        help="where to write the annotated image")
    args = parser.parse_args()

    # Best angle per class, in degrees.
    angles_deg = get_best_angles_per_class(args.image, args.weight, return_degree=True)
    print("\n🎯 各类别最佳旋转角度(度):")
    for cls_name, angle in angles_deg.items():
        if angle is not None:
            print(f" {cls_name}: {angle:.2f}°")
        else:
            print(f" {cls_name}: 未检测到")

    # Visualize every detection.
    save_obb_visual(args.image, args.weight, args.save_path)