Files
琉璃月光 8b263167f8 更新
2025-12-11 08:37:09 +08:00

192 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
# ================== Configuration ==================
# Checkpoint of the trained YOLO OBB model used for inference.
MODEL_PATH = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/obb300/weights/last.pt"
# Directory containing the images to evaluate.
IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb/val"
#IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb5/val"
LABEL_SOURCE_DIR = IMAGE_SOURCE_DIR  # labels are assumed to live next to the images
OUTPUT_DIR = "./inference_results"
# Visualisations of images with a large angle error are written here.
VISUAL_DIR = os.path.join(OUTPUT_DIR, "visual_errors_gt5deg")
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(VISUAL_DIR, exist_ok=True)

# Load the model.
print("🔄 加载 YOLO OBB 模型...")
model = YOLO(MODEL_PATH)
print("✅ 模型加载完成")

# Collect the image files. The list is sorted because os.listdir() returns
# entries in arbitrary order, which would make the log order differ from run
# to run.
image_files = sorted(
    f for f in os.listdir(IMAGE_SOURCE_DIR)
    if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
)
if not image_files:
    print("❌ 错误:未找到图像文件")
    # The bare `exit` builtin is injected by the `site` module for interactive
    # sessions and may be absent (e.g. `python -S`); SystemExit always works.
    raise SystemExit(1)
print(f"📁 发现 {len(image_files)} 张图像待处理")

all_angle_errors = []  # per-image angle error, in degrees
# ================== 工具函数 ==================
def parse_obb_label_file(label_path, img_shape):
    """Parse an OBB label file and convert its corners to pixel coordinates.

    Every valid line holds 9 whitespace-separated fields: a class id followed
    by four normalised ``x y`` corner pairs. Blank lines are ignored; lines
    with the wrong number of fields or with non-numeric values are reported
    and skipped instead of aborting the whole run.

    Args:
        label_path: Path of the label ``.txt`` file.
        img_shape: Image shape; only the leading ``(height, width)`` is used
            to de-normalise the coordinates.

    Returns:
        A list of ``{'cls': int, 'points': ndarray(4, 2)}`` dicts with the
        points in pixels. Empty when the file is missing or has no valid line.
    """
    boxes = []
    h, w = img_shape[:2]
    if not os.path.exists(label_path):
        print(f"⚠️ 标签文件不存在: {label_path}")
        return boxes
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank (or whitespace-only) line, e.g. a trailing newline.
                continue
            if len(parts) != 9:
                print(f"⚠️ 标签行格式错误 (期望9列): {parts}")
                continue
            try:
                cls_id = int(parts[0])
                coords = [float(value) for value in parts[1:]]
            except ValueError:
                # One corrupt line must not kill the evaluation of all images.
                print(f"⚠️ 标签行无法解析为数值: {parts}")
                continue
            points = np.array(coords).reshape(4, 2)
            points[:, 0] *= w  # x * width
            points[:, 1] *= h  # y * height
            boxes.append({'cls': cls_id, 'points': points})
    return boxes
def compute_main_direction(points):
    """Return the direction of the longest edge of a 4-corner box.

    The corners are walked in order (the last one connects back to the first).
    Edges no longer than 1e-6 are ignored; if several edges share the maximum
    length, the first one encountered wins.

    Args:
        points: Four 2-D corner points supporting subtraction (e.g. an
            ndarray of shape (4, 2)).

    Returns:
        The edge angle in radians, folded into [0, pi). ``0.0`` when every
        edge is degenerate.
    """
    best_len = None
    best_vec = None
    for start in range(4):
        edge = points[(start + 1) % 4] - points[start]
        edge_len = np.linalg.norm(edge)
        if not edge_len > 1e-6:
            continue
        # Strict comparison keeps the earliest edge on ties.
        if best_len is None or edge_len > best_len:
            best_len, best_vec = edge_len, edge

    if best_vec is None:
        return 0.0

    # An edge and its reverse describe the same axis, so fold by pi.
    return np.arctan2(best_vec[1], best_vec[0]) % np.pi
def compute_min_angle_between_two_dirs(dir1_rad, dir2_rad):
    """Return the smallest angle between two undirected directions.

    Directions are treated as axes, i.e. ``theta`` and ``theta + pi`` are the
    same direction, so the result always lies in [0, 90] degrees.

    Args:
        dir1_rad: First direction, in radians (any real value).
        dir2_rad: Second direction, in radians (any real value).

    Returns:
        The minimal included angle, in degrees.
    """
    # Fold the raw difference into [0, pi) first. Without this, inputs that
    # are not already normalised (difference > pi) made `np.pi - diff`
    # negative and the function returned a negative angle.
    diff = abs(dir1_rad - dir2_rad) % np.pi
    min_diff_rad = min(diff, np.pi - diff)
    return np.degrees(min_diff_rad)
def draw_boxes_on_image(image, pred_boxes=None, true_boxes=None):
    """Return a copy of ``image`` with box outlines drawn on it.

    Ground-truth boxes are outlined in red and predicted boxes in green (BGR
    colours), both 2 px thick. The input image is not modified.

    Args:
        image: Image array; it is copied before drawing.
        pred_boxes: Optional iterable of predictions; each item must expose an
            ``xyxyxyxy`` tensor whose first entry reshapes to four (x, y)
            corners.
        true_boxes: Optional iterable of dicts carrying a ``'points'`` entry
            with four (x, y) corners.

    Returns:
        The annotated copy.
    """
    red_bgr = (0, 0, 255)
    green_bgr = (0, 255, 0)
    canvas = image.copy()

    def _outline(polygon, colour):
        # cv2.polylines takes a list of polygons shaped (N, 1, 2).
        cv2.polylines(canvas, [polygon], isClosed=True, color=colour, thickness=2)

    # Ground truth is drawn first, predictions afterwards.
    if true_boxes is not None:
        for gt in true_boxes:
            _outline(np.int32(gt['points']).reshape((-1, 1, 2)), red_bgr)

    if pred_boxes is not None:
        for det in pred_boxes:
            corners = det.xyxyxyxy.cpu().numpy()[0]
            polygon = corners.reshape(4, 2).astype(int).reshape((-1, 1, 2))
            _outline(polygon, green_bgr)

    return canvas
# ================== Main loop ==================
# Angle error (degrees) above which an annotated image is saved.
# NOTE(review): the output directory name and the old log text say 5 degrees,
# but the code has always compared against 3; the value is kept and the log
# message now reports the threshold that is actually applied.
VIS_ERROR_THRESHOLD_DEG = 3

for img_filename in image_files:
    img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
    label_path = os.path.join(LABEL_SOURCE_DIR, os.path.splitext(img_filename)[0] + ".txt")
    print(f"\n🖼️ 处理: {img_filename}")

    # Read the image.
    img = cv2.imread(img_path)
    if img is None:
        print("❌ 无法读取图像")
        continue

    # Run inference.
    results = model(img, imgsz=640, conf=0.15, verbose=False)
    result = results[0]
    pred_boxes = result.obb

    # === Main direction of the predicted boxes ===
    if pred_boxes is None or len(pred_boxes) < 2:
        print("❌ 预测框不足两个")
        continue
    pred_dirs = []
    # Only the first two predictions are used.
    # NOTE(review): presumably these are the two most confident ones - confirm.
    for box in pred_boxes[:2]:
        _, _, w, h, r_rad = box.xywhr.cpu().numpy()[0]
        # The rotation refers to the width axis; when the height is the longer
        # side, the long-edge direction is perpendicular to it.
        main_dir = r_rad if w >= h else r_rad + np.pi / 2
        pred_dirs.append(main_dir % np.pi)
    pred_angle = compute_min_angle_between_two_dirs(pred_dirs[0], pred_dirs[1])

    # === Main direction of the ground-truth boxes ===
    true_boxes = parse_obb_label_file(label_path, img.shape)
    if len(true_boxes) < 2:
        print("❌ 标签框不足两个")
        continue
    true_dirs = [compute_main_direction(tb['points']) for tb in true_boxes[:2]]
    true_angle = compute_min_angle_between_two_dirs(true_dirs[0], true_dirs[1])

    # === Angle error ===
    error_deg = abs(pred_angle - true_angle)
    all_angle_errors.append(error_deg)
    print(f" 🔹 预测夹角: {pred_angle:.2f}°")
    print(f" 🔹 真实夹角: {true_angle:.2f}°")
    print(f" 🔺 夹角误差: {error_deg:.2f}°")

    # === Visualise images whose error exceeds the threshold ===
    if error_deg > VIS_ERROR_THRESHOLD_DEG:
        print(f" 🎯 误差 >{VIS_ERROR_THRESHOLD_DEG}°生成可视化图像...")
        img_with_boxes = draw_boxes_on_image(img, pred_boxes=pred_boxes, true_boxes=true_boxes)
        # Hershey fonts cannot render non-ASCII glyphs (the degree sign would
        # be drawn as '?'), so the overlay text spells out "deg".
        cv2.putText(img_with_boxes, f"Error: {error_deg:.2f} deg", (20, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        vis_output_path = os.path.join(VISUAL_DIR, f"error_{error_deg:.2f}deg_{img_filename}")
        # cv2.imwrite signals failure through its return value, not an exception.
        if cv2.imwrite(vis_output_path, img_with_boxes):
            print(f" ✅ 已保存可视化图像: {vis_output_path}")
        else:
            print(f" ❌ 保存可视化图像失败: {vis_output_path}")
# ================== Summary statistics ==================
separator = "=" * 60
print("\n" + separator)
print("📊 夹角误差统计(基于两框间最小夹角)")
print(separator)
if not all_angle_errors:
    print("❌ 无有效数据用于统计")
else:
    # Convert once and reuse the array for every aggregate.
    error_array = np.asarray(all_angle_errors)
    mean_error = error_array.mean()
    std_error = error_array.std()
    max_error = error_array.max()
    min_error = error_array.min()
    print(f"有效图像数: {len(all_angle_errors)}")
    print(f"平均夹角误差: {mean_error:.2f}°")
    print(f"标准差: {std_error:.2f}°")
    print(f"最大误差: {max_error:.2f}°")
    print(f"最小误差: {min_error:.2f}°")
print(separator)
print("🎉 所有图像处理完成!")