2025-11-03 16:10:50 +08:00
|
|
|
|
import cv2
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import os
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== User configuration ======================
# Alternative weights from an earlier training run, kept for reference:
# MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai2/weights/best.pt'
MODEL_PATH = 'best.pt'

# Directory holding the validation images.
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226'
# Directory holding the labels (one .txt per image, same base name).
# The labels live next to the images, so reuse the image directory.
LABEL_DIR = IMAGE_SOURCE_DIR
# Annotated comparison images are written to this sub-directory.
OUTPUT_DIR = os.path.join(IMAGE_SOURCE_DIR, 'output_images')

# Lower-case file extensions (with leading dot) that count as images.
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}

# Make sure the output directory exists before anything is written to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 可视化函数 ======================
|
|
|
|
|
|
def draw_keypoints_on_image(image, kpts_xy, colors, label_prefix=''):
    """Draw numbered keypoint markers onto ``image`` in place.

    Every drawable keypoint gets a filled circle of radius 8 plus the text
    ``f'{label_prefix}{index + 1}'`` offset by (+10, -10) pixels from it.

    Args:
        image: Image accepted by ``cv2.circle`` / ``cv2.putText``; it is
            modified in place.
        kpts_xy: Iterable of ``(x, y)`` coordinates in pixels.
        colors: Non-empty sequence of colour tuples handed to OpenCV;
            keypoint ``j`` uses ``colors[j % len(colors)]``.
        label_prefix: Text placed in front of the 1-based keypoint number.

    Returns:
        The same ``image`` object that was passed in.
    """
    for j, (x, y) in enumerate(kpts_xy):
        # int() raises on NaN/inf, so such points are skipped instead of
        # aborting the whole visualisation. The index j still advances, so
        # the remaining points keep their original numbers and colours.
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        px, py = int(x), int(y)
        color = colors[j % len(colors)]
        cv2.circle(image, (px, py), 8, color, -1)
        cv2.putText(image, f'{label_prefix}{j + 1}', (px + 10, py - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    return image
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 标签读取函数 ======================
|
|
|
|
|
|
def load_keypoints_from_label(label_path, img_shape):
    """Read the keypoints of the first annotation in a pose label file.

    Expected line format (coordinates are scaled by the image size here):

        <class> xc yc w h x1 y1 v1 x2 y2 v2 x3 y3 v3 x4 y4 v4

    i.e. at least 17 fields: 1 class + 4 box values + 4 * 3 keypoint values.
    Only the first non-blank line of the file is used.

    Args:
        label_path: Path of the ``.txt`` label file.
        img_shape: Image shape whose first two entries are (height, width).

    Returns:
        A ``numpy.ndarray`` of shape ``(K, 2)`` holding the pixel ``(x, y)``
        of each keypoint, or ``None`` (after printing a warning) when the
        file is missing, has fewer than 17 fields, or contains a
        non-numeric keypoint field.
    """
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            # Take the first non-blank line so a leading empty line does not
            # make an otherwise valid label look empty.
            fields = next((ln.split() for ln in f if ln.strip()), [])
    except FileNotFoundError:
        print(f"⚠️ 找不到标签文件: {label_path}")
        return None

    if len(fields) < 17:
        print(f"⚠️ 标签长度不足: {label_path} ({len(fields)}项)")
        return None

    try:
        # Skip the first five fields (class id + bounding box).
        values = [float(tok) for tok in fields[5:]]
    except ValueError:
        print(f"⚠️ 标签包含非数值字段: {label_path}")
        return None

    # Drop a trailing incomplete (x, y, v) triple so reshape cannot fail.
    usable = len(values) - len(values) % 3
    coords = np.array(values[:usable]).reshape(-1, 3)[:, :2]

    height, width = img_shape[:2]
    coords[:, 0] *= width
    coords[:, 1] *= height
    return coords
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 主程序 ======================
|
|
|
|
|
|
def main():
    """Measure keypoint error of the pose model over the image directory.

    For each image in ``IMAGE_SOURCE_DIR`` the ground-truth keypoints are
    read from the matching ``.txt`` in ``LABEL_DIR``, the model is run on
    the image, and the Euclidean pixel distance between prediction and
    ground truth is computed for keypoints 0 and 1 of the first detection.
    An image with both point sets drawn is saved to ``OUTPUT_DIR`` as
    ``compare_<image name>``, and summary statistics are printed at the end.

    Raises:
        SystemExit: With status 1 when no image file is found.
    """
    print("🚀 开始验证集关键点误差计算")

    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # Sorted because os.listdir returns entries in arbitrary order; this
    # keeps the processing and log order reproducible between runs.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    )

    if not image_files:
        print("❌ 未找到图像文件")
        # Raise SystemExit directly; the bare ``exit`` helper is injected by
        # the ``site`` module and is not guaranteed to exist.
        raise SystemExit(1)

    total_errors = []
    skipped = 0
    colors_gt = [(0, 255, 0), (0, 200, 0), (0, 150, 0), (0, 100, 0)]
    colors_pred = [(0, 0, 255)] * 4
    # Only keypoints 0 and 1 contribute to the error metric.
    selected_idx = [0, 1]

    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        label_path = os.path.join(LABEL_DIR, os.path.splitext(img_filename)[0] + '.txt')

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            skipped += 1
            continue

        gt_kpts = load_keypoints_from_label(label_path, img.shape)
        if gt_kpts is None or len(gt_kpts) < 4:
            print(f"⚠️ 跳过 {img_filename}:标签点不足")
            skipped += 1
            continue

        results = model(img, verbose=False)
        if not results or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print(f"⚠️ {img_filename}: 无检测结果,跳过")
            skipped += 1
            continue

        # Keypoints of the first detection only.
        pred_kpts = results[0].keypoints.xy[0].cpu().numpy()
        if pred_kpts.shape[0] != gt_kpts.shape[0]:
            print(f"⚠️ {img_filename}: 点数不匹配 GT={len(gt_kpts)}, Pred={len(pred_kpts)},跳过")
            skipped += 1
            continue

        # Per-point Euclidean distance in pixels, then the per-image mean.
        errors = np.linalg.norm(pred_kpts[selected_idx] - gt_kpts[selected_idx], axis=1)
        mean_error = np.mean(errors)
        total_errors.append(mean_error)

        print(f"📸 {img_filename}: 点0&1误差={np.round(errors, 2)} 像素, 平均误差={mean_error:.2f}px")

        # Visualisation: ground truth and prediction drawn on a copy.
        vis_img = img.copy()
        vis_img = draw_keypoints_on_image(vis_img, gt_kpts, colors_gt, label_prefix='GT')
        vis_img = draw_keypoints_on_image(vis_img, pred_kpts, colors_pred, label_prefix='P')

        save_path = os.path.join(OUTPUT_DIR, f"compare_{img_filename}")
        # cv2.imwrite reports failure through its return value, so check it
        # rather than silently losing the output image.
        if not cv2.imwrite(save_path, vis_img):
            print(f"⚠️ 无法保存可视化图像: {save_path}")

    # ====================== Summary ======================
    print("\n======================")
    if total_errors:
        print(f"🎯 有效样本数: {len(total_errors)} 张")
        print(f"🚫 跳过样本数: {skipped} 张")
        print(f"📈 平均关键点误差: {np.mean(total_errors):.2f} 像素")
    else:
        print(f"⚠️ 所有样本均被跳过(跳过 {skipped} 张)")
    print("======================")


if __name__ == "__main__":
    main()
|