Files
ailai_image_point_diff/ailai_pc/point_error_test.py
2025-12-28 00:12:46 +08:00

126 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from ultralytics import YOLO
# ====================== User configuration ======================
# Weights from an earlier training run, kept for reference:
# MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai2/weights/best.pt'
MODEL_PATH = 'point.pt'

# Images and their label files currently live in the same directory.
_DATASET_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251214'
IMAGE_SOURCE_DIR = _DATASET_DIR  # validation-set image directory
LABEL_DIR = _DATASET_DIR  # label directory (.txt named after the image)

# Where the comparison overlays are written.
OUTPUT_DIR = './output_images'

# File extensions (lower-case, with leading dot) treated as images.
IMG_EXTENSIONS = {
    '.png',
    '.jpg',
    '.jpeg',
    '.bmp',
    '.tiff',
    '.tif',
    '.webp',
}

# Make sure the output directory exists before anything is saved into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== 可视化函数 ======================
def draw_keypoints_on_image(image, kpts_xy, colors, label_prefix=''):
    """Mark every keypoint on ``image`` with a filled dot and a numbered tag.

    Each point gets a radius-8 filled circle and the text
    ``{label_prefix}{n}`` (``n`` counted from 1) placed 10 px to the right
    of and 10 px above the point. Colors are picked from ``colors``
    cyclically, so ``colors`` may be shorter than ``kpts_xy``.

    Args:
        image: Image handed to ``cv2.circle`` and ``cv2.putText``.
        kpts_xy: Iterable of ``(x, y)`` pairs; each value is truncated
            with ``int`` before drawing.
        colors: Sequence of color tuples; must be non-empty whenever
            there is at least one keypoint.
        label_prefix: Text placed before the keypoint number.

    Returns:
        The ``image`` object that was passed in.
    """
    for number, (raw_x, raw_y) in enumerate(kpts_xy, start=1):
        center = (int(raw_x), int(raw_y))
        tag = f'{label_prefix}{number}'
        tag_origin = (center[0] + 10, center[1] - 10)

        # Cycle through the palette using the zero-based keypoint index.
        cv2.circle(image, center, 8, colors[(number - 1) % len(colors)], -1)
        cv2.putText(image, tag, tag_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 1,
                    colors[(number - 1) % len(colors)], 2)
    return image
# ====================== 标签读取函数 ======================
def load_keypoints_from_label(label_path, img_shape):
    """Load ground-truth keypoints from a label file as pixel coordinates.

    Only the first line of the file is read. Its expected layout is::

        <class> xc yc w h x1 y1 v1 x2 y2 v2 x3 y3 v3 x4 y4 v4

    i.e. 17 whitespace-separated fields (1 class + 4 bbox + 4 * 3 keypoint
    values). The first five fields are skipped, the remaining ones are
    grouped into ``(x, y, v)`` triplets, ``v`` is dropped, and ``x`` / ``y``
    are multiplied by the image width / height.

    Args:
        label_path: Path of the ``.txt`` label file.
        img_shape: Image shape whose first two entries are (height, width).

    Returns:
        A ``(K, 2)`` float array of ``(x, y)`` pixel coordinates, or ``None``
        when the file is missing, the first line has fewer than 17 fields,
        or the keypoint fields are not numeric ``(x, y, v)`` triplets. A
        warning is printed in each ``None`` case.
    """
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            fields = f.readline().split()
    except FileNotFoundError:
        print(f"⚠️ 找不到标签文件: {label_path}")
        return None

    height, width = img_shape[:2]

    if len(fields) < 17:
        print(f"⚠️ 标签长度不足: {label_path} ({len(fields)}项)")
        return None

    try:
        # Skip the first 5 fields (class + bbox); the rest are keypoints.
        values = np.array([float(v) for v in fields[5:]], dtype=float)
        # reshape raises ValueError if the count is not a multiple of 3.
        triplets = values.reshape(-1, 3)
    except ValueError as exc:
        # A malformed label must not abort the whole evaluation run.
        print(f"⚠️ 标签格式错误: {label_path} ({exc})")
        return None

    # Drop the third column and scale x by width, y by height.
    return triplets[:, :2] * np.array([width, height], dtype=float)
# ====================== 主程序 ======================
# ====================== Main program ======================
if __name__ == "__main__":
    # Runs the model on every image in IMAGE_SOURCE_DIR, compares predicted
    # keypoints 0 and 1 with the label file of the same name, prints the
    # per-image and overall pixel error, and saves a GT-vs-prediction
    # overlay for each evaluated image into OUTPUT_DIR.
    print("🚀 开始验证集关键点误差计算")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # os.listdir returns entries in arbitrary order; sort them so the log
    # is reproducible from run to run.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    )
    if not image_files:
        print("❌ 未找到图像文件")
        # The bare `exit` helper is injected by the `site` module for
        # interactive sessions; raising SystemExit works everywhere.
        raise SystemExit(1)

    total_errors = []  # one mean error (in pixels) per evaluated image
    skipped = 0
    colors_gt = [(0, 255, 0), (0, 200, 0), (0, 150, 0), (0, 100, 0)]
    colors_pred = [(0, 0, 255)] * 4

    # Only keypoints 0 and 1 contribute to the error metric.
    selected_idx = [0, 1]

    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        label_path = os.path.join(
            LABEL_DIR, os.path.splitext(img_filename)[0] + '.txt'
        )

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            skipped += 1
            continue

        gt_kpts = load_keypoints_from_label(label_path, img.shape)
        if gt_kpts is None or len(gt_kpts) < 4:
            print(f"⚠️ 跳过 {img_filename}:标签点不足")
            skipped += 1
            continue

        results = model(img, verbose=False)
        if not results or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print(f"⚠️ {img_filename}: 无检测结果,跳过")
            skipped += 1
            continue

        # Use the keypoints of the first detection only.
        pred_kpts = results[0].keypoints.xy[0].cpu().numpy()
        if pred_kpts.shape[0] != gt_kpts.shape[0]:
            print(f"⚠️ {img_filename}: 点数不匹配 GT={len(gt_kpts)}, Pred={len(pred_kpts)},跳过")
            skipped += 1
            continue

        # Euclidean distance per selected keypoint, then the image mean.
        errors = np.linalg.norm(
            pred_kpts[selected_idx] - gt_kpts[selected_idx], axis=1
        )
        mean_error = np.mean(errors)
        total_errors.append(mean_error)
        print(f"📸 {img_filename}: 点0&1误差={np.round(errors, 2)} 像素, 平均误差={mean_error:.2f}px")

        # Visualization: draw GT and predicted points on a copy of the image.
        vis_img = img.copy()
        vis_img = draw_keypoints_on_image(vis_img, gt_kpts, colors_gt, label_prefix='GT')
        vis_img = draw_keypoints_on_image(vis_img, pred_kpts, colors_pred, label_prefix='P')
        save_path = os.path.join(OUTPUT_DIR, f"compare_{img_filename}")
        # cv2.imwrite reports failure through its return value, not an
        # exception, so surface it instead of silently losing the overlay.
        if not cv2.imwrite(save_path, vis_img):
            print(f"⚠️ 无法保存可视化图像: {save_path}")

    # ====================== Summary ======================
    print("\n======================")
    if total_errors:
        print(f"🎯 有效样本数: {len(total_errors)}")
        print(f"🚫 跳过样本数: {skipped}")
        print(f"📈 平均关键点误差: {np.mean(total_errors):.2f} 像素")
    else:
        print(f"⚠️ 所有样本均被跳过(跳过 {skipped} 张)")
    print("======================")