Files
zjsh_yolov11/angle_base_obb/yanzheng_move.py
琉璃月光 8b263167f8 更新
2025-12-11 08:37:09 +08:00

184 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from ultralytics import YOLO
import shutil
# ================== Configuration ==================
MODEL_PATH = r"/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_obb_new/weights/best.pt"
IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/zjdata16"
# IMAGE_SOURCE_DIR = r"/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb5/val"
LABEL_SOURCE_DIR = IMAGE_SOURCE_DIR  # label .txt files live in the same directory as the images
TEST_OUTPUT_DIR = os.path.join(IMAGE_SOURCE_DIR, "test")  # samples with a large angle error are moved here
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
# Create the test/ output directory
os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
# Collect candidate images: regular files whose (case-insensitive) extension is
# in IMG_EXTENSIONS. Directories (e.g. test/ itself) are excluded, and the list
# is sorted so the processing order, and therefore the log, is reproducible.
image_files = sorted(
    f for f in os.listdir(IMAGE_SOURCE_DIR)
    if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    and os.path.isfile(os.path.join(IMAGE_SOURCE_DIR, f))
)
if not image_files:
    print(f"❌ 错误:未找到图像文件")
    # SystemExit instead of the site-provided exit(): exit() is intended for the
    # interactive interpreter and does not exist when Python runs with -S.
    raise SystemExit(1)
print(f"📁 发现 {len(image_files)} 张图像待处理")
# Load the model only once we know there is something to process.
print("🔄 加载 YOLO OBB 模型...")
model = YOLO(MODEL_PATH)
print("✅ 模型加载完成")
all_angle_errors = []  # per-image angle error in degrees, used for the final statistics
# ================== 工具函数 ==================
def parse_obb_label_file(label_path, img_shape):
    """Parse a YOLO-OBB label file and convert normalized corners to pixels.

    Each non-empty line must contain 9 whitespace-separated fields,
    ``cls x1 y1 x2 y2 x3 y3 x4 y4``. The x values are scaled by the image
    width and the y values by the image height.

    Args:
        label_path: path of the ``.txt`` label file.
        img_shape: image shape; only ``img_shape[:2]`` = (height, width) is used.

    Returns:
        A list of ``{'cls': int, 'points': np.ndarray}`` dicts, where ``points``
        has shape (4, 2) in pixel coordinates. An empty list is returned when
        the file does not exist. Blank lines are skipped silently. Lines with
        the wrong column count or non-numeric values are reported and skipped.
    """
    boxes = []
    h, w = img_shape[:2]
    if not os.path.exists(label_path):
        print(f"⚠️ 标签文件不存在: {label_path}")
        return boxes
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                # Blank line (e.g. a trailing newline): not a format error.
                continue
            if len(parts) != 9:
                print(f"⚠️ 标签行格式错误 (期望9列): {parts}")
                continue
            try:
                cls_id = int(parts[0])
                coords = [float(v) for v in parts[1:]]
            except ValueError:
                # One malformed line must not abort the whole evaluation run.
                print(f"⚠️ 标签行数值无法解析: {parts}")
                continue
            points = np.array(coords).reshape(4, 2)
            points[:, 0] *= w  # x * width
            points[:, 1] *= h  # y * height
            boxes.append({'cls': cls_id, 'points': points})
    return boxes
def compute_main_direction(points):
    """Return the heading of a quadrilateral's longest side, folded into [0, pi).

    The four vertices ``points[0..3]`` are visited in order, closing the loop
    from vertex 3 back to vertex 0. Sides of length <= 1e-6 are ignored; among
    the remaining sides the first longest one wins. Its ``arctan2`` angle is
    reduced modulo pi so that a side and its reverse give the same direction.

    Returns 0.0 when all four sides are degenerate.
    """
    side_vectors = [points[(k + 1) % 4] - points[k] for k in range(4)]
    side_lengths = [np.linalg.norm(vec) for vec in side_vectors]
    usable = [k for k in range(4) if side_lengths[k] > 1e-6]
    if not usable:
        return 0.0
    # max() keeps the earliest index among equally long sides (vertex order).
    winner = max(usable, key=side_lengths.__getitem__)
    chosen = side_vectors[winner]
    return np.arctan2(chosen[1], chosen[0]) % np.pi
def compute_min_angle_between_two_dirs(dir1_rad, dir2_rad):
    """Return the acute angle between two undirected lines, in degrees (0 to 90).

    Directions are compared modulo pi (a line and its reverse are the same
    line). The absolute difference is therefore reduced into [0, pi) before
    taking the smaller of the angle and its supplement. Without that reduction,
    inputs not already in [0, pi) would yield negative results.

    Args:
        dir1_rad: first direction in radians.
        dir2_rad: second direction in radians.

    Returns:
        The minimum angle between the two directions, in degrees.
    """
    diff = abs(dir1_rad - dir2_rad) % np.pi
    min_diff_rad = min(diff, np.pi - diff)
    return np.degrees(min_diff_rad)
# ================== Main loop ==================
# For each image: run OBB inference, take the first two predicted boxes and the
# first two label boxes, compute the angle (0-90 deg) between the two boxes on
# each side, and record |predicted angle - labelled angle| in all_angle_errors.
# Images with fewer than two predicted or labelled boxes are skipped and do not
# contribute to the statistics. Images whose error exceeds 1.5 deg are moved,
# together with their label file, into TEST_OUTPUT_DIR.
for img_filename in image_files:
    stem = os.path.splitext(img_filename)[0]
    img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
    label_path = os.path.join(LABEL_SOURCE_DIR, stem + ".txt")
    print(f"\n🖼️ 处理: {img_filename}")
    # Read the image (cv2.imread returns None on failure rather than raising)
    img = cv2.imread(img_path)
    if img is None:
        print("❌ 无法读取图像")
        continue
    # Inference
    results = model(img, imgsz=640, conf=0.15, verbose=False)
    result = results[0]
    pred_boxes = result.obb
    # === Main direction of the first two predicted boxes (at least two are required) ===
    # NOTE(review): presumably the boxes are ordered by confidence, so these are
    # the two most confident detections — confirm against the Ultralytics docs.
    pred_dirs = []
    if pred_boxes is not None and len(pred_boxes) >= 2:
        for box in pred_boxes[:2]:
            xywhr = box.xywhr.cpu().numpy()[0]
            cx, cy, w, h, r_rad = xywhr
            # Use the long side as the main direction. This assumes r_rad is
            # measured along the w side (Ultralytics xywhr convention), so 90 deg
            # is added when h is the longer side — TODO confirm.
            main_dir = r_rad if w >= h else r_rad + np.pi / 2
            pred_dirs.append(main_dir % np.pi)
        pred_angle = compute_min_angle_between_two_dirs(pred_dirs[0], pred_dirs[1])
    else:
        print("❌ 预测框不足两个")
        continue
    # === Main direction of the first two label boxes (at least two are required) ===
    true_boxes = parse_obb_label_file(label_path, img.shape)
    if len(true_boxes) < 2:
        print("❌ 标签框不足两个")
        continue
    true_dirs = []
    for tb in true_boxes[:2]:
        d = compute_main_direction(tb['points'])
        true_dirs.append(d)
    true_angle = compute_min_angle_between_two_dirs(true_dirs[0], true_dirs[1])
    # === Angle error ===
    # Only the angle *between* the two boxes is compared, so the order in which
    # boxes appear in the prediction vs. the label file does not matter.
    error_deg = abs(pred_angle - true_angle)
    all_angle_errors.append(error_deg)
    print(f" 🔹 预测夹角: {pred_angle:.2f}°")
    print(f" 🔹 真实夹角: {true_angle:.2f}°")
    print(f" 🔺 夹角误差: {error_deg:.2f}°")
    # === If the error is > 1.5 deg, move the original image and its .txt label into test/ ===
    if error_deg > 1.5:
        print(f" 🚩 误差 >1.5°,移动原文件到 test/ ...")
        # Build destination paths
        img_dst = os.path.join(TEST_OUTPUT_DIR, img_filename)
        txt_dst = os.path.join(TEST_OUTPUT_DIR, stem + ".txt")
        try:
            # Move the image
            shutil.move(img_path, img_dst)
            # NOTE(review): source and destination paths are printed with no
            # separator between them; a separator (e.g. an arrow) may have been
            # lost — confirm.
            print(f" ✅ 移动图像: {img_path}{img_dst}")
            # Move the label (if it exists)
            if os.path.exists(label_path):
                shutil.move(label_path, txt_dst)
                print(f" ✅ 移动标签: {label_path}{txt_dst}")
            else:
                print(f" ⚠️ 标签不存在,仅移动图像")
        except Exception as e:
            # NOTE(review): if the image move succeeded but the label move failed,
            # the image is already in test/ while its label stays behind.
            print(f" ❌ 移动失败: {e}")
# ================== Summary statistics ==================
# Aggregate the per-image angle errors collected by the main loop. The
# standard deviation is the population one (NumPy's default ddof=0).
divider = "=" * 60
print("\n" + divider)
print("📊 夹角误差统计(基于两框间最小夹角)")
print(divider)
if not all_angle_errors:
    print("❌ 无有效数据用于统计")
else:
    errors = np.asarray(all_angle_errors)
    mean_error = errors.mean()
    std_error = errors.std()
    max_error = errors.max()
    min_error = errors.min()
    for summary_line in (
        f"有效图像数: {len(all_angle_errors)}",
        f"平均夹角误差: {mean_error:.2f}°",
        f"标准差: {std_error:.2f}°",
        f"最大误差: {max_error:.2f}°",
        f"最小误差: {min_error:.2f}°",
    ):
        print(summary_line)
print(divider)
print("🎉 所有图像处理完成!")
print(f"⚠️ 误差 >1.5° 的样本已移至: {TEST_OUTPUT_DIR}")