Files
ailai_image_point_diff/ailai_pc/point_error_test.py

126 lines
4.7 KiB
Python
Raw Normal View History

2025-11-03 16:10:50 +08:00
import cv2
import numpy as np
import os
from ultralytics import YOLO
# ====================== User configuration ======================
# MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai2/weights/best.pt'
MODEL_PATH = 'pointn.pt'
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/test'  # validation image directory
LABEL_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/test'  # label dir (.txt sharing the image basename)
OUTPUT_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/test/output_images'

# Image extensions accepted when scanning IMAGE_SOURCE_DIR (compared lowercase).
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}

# Make sure the visualization output directory exists before the main loop runs.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== 可视化函数 ======================
def draw_keypoints_on_image(image, kpts_xy, colors, label_prefix=''):
    """Draw numbered keypoint markers onto *image* in place and return it.

    Each (x, y) pair in *kpts_xy* becomes a filled circle plus a text label
    '<label_prefix><index+1>' offset to its upper right; marker colors cycle
    through *colors*.
    """
    for idx, point in enumerate(kpts_xy):
        px, py = int(point[0]), int(point[1])
        color = colors[idx % len(colors)]
        cv2.circle(image, (px, py), 8, color, -1)
        cv2.putText(image, f'{label_prefix}{idx + 1}', (px + 10, py - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    return image
# ====================== 标签读取函数 ======================
def load_keypoints_from_label(label_path, img_shape):
    """Load the first object's keypoints from a YOLO-pose label file.

    Expected label format (first line only is read):
        <class> xc yc w h x1 y1 v1 x2 y2 v2 x3 y3 v3 x4 y4 v4
    i.e. 17 fields: 1 class id + 4 normalized bbox values + 4 * (x, y, v).

    Args:
        label_path: Path to the .txt label file.
        img_shape: Image shape as (H, W, ...) used to denormalize coordinates.

    Returns:
        (N, 2) float array of pixel-space (x, y) keypoints, or None when the
        file is missing or malformed.
    """
    if not os.path.exists(label_path):
        print(f"⚠️ 找不到标签文件: {label_path}")
        return None
    H, W = img_shape[:2]
    with open(label_path, 'r') as f:
        line = f.readline().strip().split()  # only the first object is used
    if len(line) < 17:
        print(f"⚠️ 标签长度不足: {label_path} ({len(line)}项)")
        return None
    floats = [float(x) for x in line[5:]]  # skip class id + bbox (5 fields)
    # Guard against a truncated keypoint list that would break the reshape below.
    if len(floats) % 3 != 0:
        print(f"⚠️ 关键点字段数量异常: {label_path} ({len(floats)}项)")
        return None
    coords = np.array(floats).reshape(-1, 3)[:, :2]  # drop visibility flag -> (N, 2)
    coords[:, 0] *= W  # denormalize x to pixels
    coords[:, 1] *= H  # denormalize y to pixels
    return coords
# ====================== 主程序 ======================
# ====================== Main program ======================
if __name__ == "__main__":
    print("🚀 开始验证集关键点误差计算")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # Collect every image file in the source directory (case-insensitive extension).
    image_files = [
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    ]
    if not image_files:
        print("❌ 未找到图像文件")
        raise SystemExit(1)

    total_errors = []  # per-image mean pixel error over the selected keypoints
    skipped = 0        # images dropped (unreadable, bad label, or no detection)
    # Ground truth is drawn in shades of green, predictions in red.
    colors_gt = [(0, 255, 0), (0, 200, 0), (0, 150, 0), (0, 100, 0)]
    colors_pred = [(0, 0, 255)] * 4

    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        label_path = os.path.join(LABEL_DIR, os.path.splitext(img_filename)[0] + '.txt')

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            skipped += 1
            continue

        gt_kpts = load_keypoints_from_label(label_path, img.shape)
        if gt_kpts is None or len(gt_kpts) < 4:
            print(f"⚠️ 跳过 {img_filename}:标签点不足")
            skipped += 1
            continue

        results = model(img, verbose=False)
        if not results or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print(f"⚠️ {img_filename}: 无检测结果,跳过")
            skipped += 1
            continue

        # Keypoints of the first detected instance only.
        pred_kpts = results[0].keypoints.xy[0].cpu().numpy()
        if pred_kpts.shape[0] != gt_kpts.shape[0]:
            print(f"⚠️ {img_filename}: 点数不匹配 GT={len(gt_kpts)}, Pred={len(pred_kpts)},跳过")
            skipped += 1
            continue

        # ====================== Evaluate keypoints 0 and 1 only ======================
        selected_idx = [0, 1]  # only the first two points contribute to the error
        errors = np.linalg.norm(pred_kpts[selected_idx] - gt_kpts[selected_idx], axis=1)
        mean_error = np.mean(errors)
        total_errors.append(mean_error)
        print(f"📸 {img_filename}: 点0&1误差={np.round(errors, 2)} 像素, 平均误差={mean_error:.2f}px")

        # Visualization: overlay GT and predicted points on a copy of the image.
        vis_img = img.copy()
        vis_img = draw_keypoints_on_image(vis_img, gt_kpts, colors_gt, label_prefix='GT')
        vis_img = draw_keypoints_on_image(vis_img, pred_kpts, colors_pred, label_prefix='P')
        save_path = os.path.join(OUTPUT_DIR, f"compare_{img_filename}")
        cv2.imwrite(save_path, vis_img)

    # ====================== Summary ======================
    print("\n======================")
    if total_errors:
        print(f"🎯 有效样本数: {len(total_errors)}")
        print(f"🚫 跳过样本数: {skipped}")
        print(f"📈 平均关键点误差: {np.mean(total_errors):.2f} 像素")
    else:
        print(f"⚠️ 所有样本均被跳过(跳过 {skipped} 张)")
    print("======================")