Files
zjsh_yolov11/angle_base_obb/error_E.py
2025-09-11 20:44:35 +08:00

154 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from ultralytics import YOLO
# ================== Configuration ==================
MODEL_PATH = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb4/weights/best.pt"
IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb2/test"
LABEL_SOURCE_DIR = IMAGE_SOURCE_DIR  # labels are assumed to sit in the same directory as the images
OUTPUT_DIR = "./inference_results"
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Collect the image list before loading the model so an empty directory
# fails fast instead of after the (slow) model load. os.listdir returns
# entries in arbitrary order; sorting makes runs reproducible.
image_files = sorted(
    f for f in os.listdir(IMAGE_SOURCE_DIR)
    if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
)
if not image_files:
    print(f"❌ 错误:未找到图像文件 ({IMAGE_SOURCE_DIR})")
    # SystemExit is always available; the bare exit() helper is injected by
    # the site module and is absent under `python -S` or in frozen builds.
    raise SystemExit(1)
print(f"📁 发现 {len(image_files)} 张图像待处理")

# Load the model
print("🔄 加载 YOLO OBB 模型...")
model = YOLO(MODEL_PATH)
print("✅ 模型加载完成")

all_angle_errors = []  # per-image angle error between the two boxes (degrees)
# ================== Helper functions ==================
def parse_obb_label_file(label_path):
    """Parse an OBB label file into a list of boxes.

    Each valid line has exactly 9 whitespace-separated fields: an integer
    class id followed by four (x, y) vertex coordinates. Lines with a
    different field count, or whose fields do not parse as numbers, are
    skipped rather than aborting the whole evaluation run.

    Args:
        label_path: path to the ``.txt`` label file.

    Returns:
        A list of ``{'cls': int, 'points': np.ndarray of shape (4, 2)}``
        in file order; an empty list when the file does not exist.
    """
    boxes = []
    if not os.path.exists(label_path):
        return boxes
    # 'utf-8-sig' transparently strips a leading BOM, which would otherwise
    # glue itself to the first class id and make int() fail.
    with open(label_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 9:
                continue
            try:
                cls_id = int(parts[0])
                coords = [float(v) for v in parts[1:]]
            except ValueError:
                # Malformed numeric field: skip it like any other bad line.
                continue
            boxes.append({'cls': cls_id, 'points': np.array(coords).reshape(4, 2)})
    return boxes
def compute_main_direction(points):
    """
    Return the direction of an oriented box's longest edge, in radians.

    ``points`` holds four vertices in traversal order; edge ``k`` runs from
    ``points[k]`` to ``points[(k + 1) % 4]``. Edges whose length is not
    greater than 1e-6 are ignored. Among equally long edges the first one
    wins. The edge angle is folded with ``% np.pi`` so that a line and its
    reverse yield the same value (nominally the range [0, π)). If every
    edge is degenerate, 0.0 is returned.
    """
    best_len = None
    best_vec = None
    for idx in range(4):
        start, end = points[idx], points[(idx + 1) % 4]
        edge = end - start
        edge_len = np.linalg.norm(edge)
        # Written as "not >" so NaN lengths are skipped, matching a plain
        # "> 1e-6" acceptance test.
        if not edge_len > 1e-6:
            continue
        # Strict '>' keeps the earliest of several equally long edges.
        if best_len is None or edge_len > best_len:
            best_len, best_vec = edge_len, edge
    if best_vec is None:
        return 0.0
    return np.arctan2(best_vec[1], best_vec[0]) % np.pi
def compute_min_angle_between_two_dirs(dir1_rad, dir2_rad):
    """Return the smallest angle between two undirected lines, in degrees.

    Args:
        dir1_rad, dir2_rad: line directions in radians. They are treated
            modulo π (a line and its reverse are the same direction), so
            they need not be pre-normalized.

    Returns:
        The acute/right angle between the two lines, in [0, 90] degrees.
    """
    # Reducing modulo π first keeps the result in range even when the inputs
    # are not already in [0, π); without it, a raw difference above π made
    # ``π - diff`` negative and produced a negative angle.
    diff = abs(dir1_rad - dir2_rad) % np.pi
    min_diff_rad = min(diff, np.pi - diff)
    return np.degrees(min_diff_rad)
# ================== Main loop ==================
for img_filename in image_files:
    img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
    label_path = os.path.join(LABEL_SOURCE_DIR, os.path.splitext(img_filename)[0] + ".txt")
    print(f"\n🖼️ 处理: {img_filename}")

    # Read the image
    img = cv2.imread(img_path)
    if img is None:
        print("❌ 无法读取图像")
        continue
    img_h, img_w = img.shape[:2]

    # === Ground-truth main directions ===
    # Checked before inference so images without two labelled boxes do not
    # pay for a forward pass.
    true_boxes = parse_obb_label_file(label_path)
    if len(true_boxes) < 2:
        print("❌ 标签框不足两个")
        continue
    pixel_scale = np.array([img_w, img_h], dtype=float)
    true_dirs = []
    for tb in true_boxes[:2]:  # first two labelled boxes
        pts = tb['points']
        # Ultralytics OBB labels store vertices normalized to [0, 1], while the
        # predicted xywhr boxes below are in pixels. Measuring edge angles (and
        # picking the longest edge) in normalized space distorts them on
        # non-square images, so scale normalized-looking labels to pixels first.
        if pts.max() <= 1.0:
            pts = pts * pixel_scale
        true_dirs.append(compute_main_direction(pts))
    true_angle = compute_min_angle_between_two_dirs(true_dirs[0], true_dirs[1])

    # === Inference ===
    results = model(img, imgsz=640, conf=0.15, verbose=False)
    pred_boxes = results[0].obb
    if pred_boxes is None or len(pred_boxes) < 2:
        print("❌ 预测框不足两个")
        continue

    # === Predicted main directions ===
    pred_dirs = []
    # Only the first two predictions are used (presumably the highest-scoring
    # ones, since NMS output is score-ordered — TODO confirm).
    for box in pred_boxes[:2]:
        _, _, w, h, r_rad = box.xywhr.cpu().numpy()[0]
        # r is taken as the orientation of the w side; when h is the longer
        # side, the main (long-side) direction is perpendicular to it.
        main_dir = r_rad if w >= h else r_rad + np.pi / 2
        pred_dirs.append(main_dir % np.pi)
    pred_angle = compute_min_angle_between_two_dirs(pred_dirs[0], pred_dirs[1])

    # === Angle error ===
    error_deg = abs(pred_angle - true_angle)
    all_angle_errors.append(error_deg)
    print(f"   🔹 预测夹角: {pred_angle:.2f}°")
    print(f"   🔹 真实夹角: {true_angle:.2f}°")
    print(f"   🔺 夹角误差: {error_deg:.2f}°")
# ================== Summary statistics ==================
# Aggregate the per-image angle errors collected by the main loop.
print("\n" + "=" * 60)
print("📊 夹角误差统计(基于两框间最小夹角)")
print("=" * 60)
if not all_angle_errors:
    print("❌ 无有效数据用于统计")
else:
    errors = np.asarray(all_angle_errors)
    mean_error = errors.mean()
    std_error = errors.std()  # population std (ddof=0), same as np.std
    max_error = errors.max()
    min_error = errors.min()
    print(f"有效图像数: {len(all_angle_errors)}")
    print(f"平均夹角误差: {mean_error:.2f}°")
    print(f"标准差: {std_error:.2f}°")
    print(f"最大误差: {max_error:.2f}°")
    print(f"最小误差: {min_error:.2f}°")
print("=" * 60)
print("🎉 所有图像处理完成!")