first commit
This commit is contained in:
145
ailai_pc/angle.py
Normal file
145
ailai_pc/angle.py
Normal file
@ -0,0 +1,145 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# 设置类别名称(必须与训练时一致)
|
||||
CLASS_NAMES = ['ban', 'bag'] # ✅ 确保顺序正确,对应模型的 class_id
|
||||
COLORS = [(0, 255, 0), (255, 0, 0)] # ban: 绿色, bag: 蓝色
|
||||
|
||||
|
||||
def get_best_angles_per_class(image_path, weight_path, return_degree=False):
|
||||
"""
|
||||
输入:
|
||||
image_path: 图像路径
|
||||
weight_path: YOLO OBB 权重路径
|
||||
return_degree: 是否返回角度(单位:度),否则为弧度
|
||||
输出:
|
||||
字典:{ class_name: best_angle 或 None }
|
||||
"""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"❌ 无法读取图像:{image_path}")
|
||||
return {cls: None for cls in CLASS_NAMES}
|
||||
|
||||
model = YOLO(weight_path)
|
||||
results = model(img, save=False, imgsz=640, conf=0.15, task='obb')
|
||||
result = results[0]
|
||||
|
||||
boxes = result.obb
|
||||
if boxes is None or len(boxes) == 0:
|
||||
print("⚠️ 未检测到任何目标。")
|
||||
return {cls: None for cls in CLASS_NAMES}
|
||||
|
||||
# 提取数据
|
||||
xywhr = boxes.xywhr.cpu().numpy() # (N, 5) -> cx, cy, w, h, r (弧度)
|
||||
confs = boxes.conf.cpu().numpy() # (N,)
|
||||
class_ids = boxes.cls.cpu().numpy().astype(int) # (N,)
|
||||
|
||||
# 初始化结果字典
|
||||
best_angles = {cls: None for cls in CLASS_NAMES}
|
||||
|
||||
# 对每个类别找置信度最高的框
|
||||
for class_id, class_name in enumerate(CLASS_NAMES):
|
||||
mask = (class_ids == class_id)
|
||||
if not np.any(mask):
|
||||
print(f"🟡 未检测到类别: {class_name}")
|
||||
continue
|
||||
|
||||
# 找该类别中置信度最高的
|
||||
idx_in_class = np.argmax(confs[mask])
|
||||
global_idx = np.where(mask)[0][idx_in_class]
|
||||
angle_rad = xywhr[global_idx][4]
|
||||
|
||||
best_angles[class_name] = np.degrees(angle_rad) if return_degree else angle_rad
|
||||
|
||||
return best_angles
|
||||
|
||||
|
||||
def save_obb_visual(image_path, weight_path, save_path):
|
||||
"""
|
||||
输入:
|
||||
image_path: 图像路径
|
||||
weight_path: YOLO权重路径
|
||||
save_path: 保存带标注图像路径
|
||||
功能:
|
||||
检测所有 OBB,绘制框、类别名、旋转角度,保存图片
|
||||
"""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"❌ 无法读取图像:{image_path}")
|
||||
return
|
||||
|
||||
model = YOLO(weight_path)
|
||||
results = model(img, save=False, imgsz=640, conf=0.05, task='obb')
|
||||
result = results[0]
|
||||
|
||||
boxes = result.obb
|
||||
if boxes is None or len(boxes) == 0:
|
||||
print("⚠️ 未检测到任何目标。")
|
||||
# 仍保存原图
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
cv2.imwrite(save_path, img)
|
||||
return
|
||||
|
||||
# 提取信息
|
||||
xywhr = boxes.xywhr.cpu().numpy()
|
||||
confs = boxes.conf.cpu().numpy()
|
||||
class_ids = boxes.cls.cpu().numpy().astype(int)
|
||||
|
||||
# 绘制
|
||||
annotated_img = img.copy()
|
||||
for i in range(len(boxes)):
|
||||
cx, cy, w, h, r = xywhr[i]
|
||||
angle_deg = np.degrees(r)
|
||||
class_id = class_ids[i]
|
||||
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"cls{class_id}"
|
||||
conf = confs[i]
|
||||
color = COLORS[class_id % len(COLORS)] if class_id < len(CLASS_NAMES) else (128, 128, 128)
|
||||
|
||||
# 绘制旋转框
|
||||
rect = ((cx, cy), (w, h), angle_deg)
|
||||
box_pts = cv2.boxPoints(rect).astype(int)
|
||||
cv2.polylines(annotated_img, [box_pts], isClosed=True, color=color, thickness=2)
|
||||
|
||||
# 标注文本:类别 + 置信度 + 角度
|
||||
text = f"{class_name} {conf:.2f} {angle_deg:.1f}°"
|
||||
font_scale = 0.7
|
||||
thickness = 2
|
||||
text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||||
|
||||
# 文本背景
|
||||
cv2.rectangle(annotated_img,
|
||||
(box_pts[0][0], box_pts[0][1] - text_size[1] - 8),
|
||||
(box_pts[0][0] + text_size[0], box_pts[0][1] + 2),
|
||||
color, -1)
|
||||
# 文本
|
||||
cv2.putText(annotated_img, text,
|
||||
(box_pts[0][0], box_pts[0][1] - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)
|
||||
|
||||
# 保存
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
cv2.imwrite(save_path, annotated_img)
|
||||
print(f"✅ 检测结果已保存至: {save_path}")
|
||||
|
||||
|
||||
# ===============================
|
||||
# 示例调用
|
||||
# ===============================
|
||||
if __name__ == "__main__":
|
||||
weight = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb3/weights/best.pt"
|
||||
image = r"/home/hx/yolo/output_masks/2.jpg"
|
||||
save_path = "./inference_results/visualized_2.jpg"
|
||||
|
||||
# 获取每个类别的最佳角度(以度为单位)
|
||||
angles_deg = get_best_angles_per_class(image, weight, return_degree=True)
|
||||
print("\n🎯 各类别最佳旋转角度(度):")
|
||||
for cls_name, angle in angles_deg.items():
|
||||
if angle is not None:
|
||||
print(f" {cls_name}: {angle:.2f}°")
|
||||
else:
|
||||
print(f" {cls_name}: 未检测到")
|
||||
|
||||
# 可视化所有检测结果
|
||||
save_obb_visual(image, weight, save_path)
|
||||
Reference in New Issue
Block a user