"""Evaluate a YOLO keypoint model against ground-truth labels.

For every image in IMAGE_SOURCE_DIR the script loads the matching label
file from LABEL_DIR, runs the model, computes the pixel distance between
predicted and ground-truth keypoints 0 and 1, and writes a side-by-side
visualization to OUTPUT_DIR.
"""
import os
from typing import Optional

import cv2
import numpy as np
from ultralytics import YOLO

# ====================== User configuration ======================
# MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai2/weights/best.pt'
MODEL_PATH = 'point.pt'
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251212'  # validation image directory
LABEL_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251212'  # label directory (.txt with the same stem as the image)
OUTPUT_DIR = './output_images'
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}

# Label layout: 1 class id + 4 bbox values, then (x, y, v) triplets.
_LABEL_HEADER_FIELDS = 5
_MIN_LABEL_FIELDS = 17  # header + 4 keypoints * 3 values

os.makedirs(OUTPUT_DIR, exist_ok=True)


# ====================== Visualization ======================
def draw_keypoints_on_image(image, kpts_xy, colors, label_prefix=''):
    """Draw numbered keypoints onto ``image`` in place and return it.

    Args:
        image: Image array passed straight to ``cv2.circle``/``cv2.putText``;
            it is modified in place.
        kpts_xy: Iterable of ``(x, y)`` pixel coordinates. Values are
            truncated to ``int`` before drawing.
        colors: Non-empty sequence of colour tuples; point ``j`` uses
            ``colors[j % len(colors)]``.
        label_prefix: Text placed before the 1-based point number.

    Returns:
        The same ``image`` object that was passed in.
    """
    for j, (x, y) in enumerate(kpts_xy):
        x, y = int(x), int(y)
        color = colors[j % len(colors)]
        cv2.circle(image, (x, y), 8, color, -1)  # thickness -1 => filled disc
        cv2.putText(
            image,
            f'{label_prefix}{j+1}',
            (x + 10, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            color,
            2,
        )
    return image


# ====================== Label loading ======================
def load_keypoints_from_label(label_path, img_shape) -> Optional[np.ndarray]:
    """Read the keypoints of the first object in a label file.

    Expected line format (whitespace separated, at least 17 fields)::

        class xc yc w h x1 y1 v1 x2 y2 v2 x3 y3 v3 x4 y4 v4

    Only the first line of the file is read. The x/y values are multiplied
    by the image width/height respectively; the third value of each triplet
    is discarded.

    Args:
        label_path: Path to the ``.txt`` label file.
        img_shape: Image shape whose first two entries are ``(H, W)``.

    Returns:
        A float array of shape ``(K, 2)`` with scaled coordinates, or
        ``None`` (after printing a warning) when the file is missing, too
        short, or malformed.
    """
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            fields = f.readline().split()
    except FileNotFoundError:
        print(f"⚠️ 找不到标签文件: {label_path}")
        return None

    if len(fields) < _MIN_LABEL_FIELDS:
        print(f"⚠️ 标签长度不足: {label_path} ({len(fields)}项)")
        return None

    kpt_fields = fields[_LABEL_HEADER_FIELDS:]  # skip class + bbox
    # A trailing partial triplet would make reshape(-1, 3) raise; treat the
    # label as malformed instead of aborting the whole evaluation run.
    if len(kpt_fields) % 3 != 0:
        print(f"⚠️ 标签格式错误(关键点字段数不是3的倍数): {label_path} ({len(fields)}项)")
        return None

    try:
        values = [float(tok) for tok in kpt_fields]
    except ValueError:
        print(f"⚠️ 标签包含非数字内容: {label_path}")
        return None

    height, width = img_shape[:2]
    triplets = np.array(values).reshape(-1, 3)
    # Keep only (x, y) and scale each column by the matching image dimension.
    return triplets[:, :2] * np.array([width, height])


# ====================== Main program ======================
def main():
    """Run the model over every image and report keypoint 0/1 pixel error.

    Exits with status 1 when IMAGE_SOURCE_DIR contains no image files.
    """
    print("🚀 开始验证集关键点误差计算")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # os.listdir order is arbitrary; sort so runs are reproducible.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    )
    if not image_files:
        print("❌ 未找到图像文件")
        # SystemExit works even where the site-provided exit() is unavailable.
        raise SystemExit(1)

    total_errors = []
    skipped = 0
    colors_gt = [(0, 255, 0), (0, 200, 0), (0, 150, 0), (0, 100, 0)]
    colors_pred = [(0, 0, 255)] * 4
    selected_idx = [0, 1]  # only the first two keypoints are scored

    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        label_path = os.path.join(LABEL_DIR, os.path.splitext(img_filename)[0] + '.txt')

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            skipped += 1
            continue

        gt_kpts = load_keypoints_from_label(label_path, img.shape)
        if gt_kpts is None or len(gt_kpts) < 4:
            print(f"⚠️ 跳过 {img_filename}:标签点不足")
            skipped += 1
            continue

        results = model(img, verbose=False)
        if not results or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print(f"⚠️ {img_filename}: 无检测结果,跳过")
            skipped += 1
            continue

        # Only the first detection is compared against the label.
        pred_kpts = results[0].keypoints.xy[0].cpu().numpy()
        if pred_kpts.shape[0] != gt_kpts.shape[0]:
            print(f"⚠️ {img_filename}: 点数不匹配 GT={len(gt_kpts)}, Pred={len(pred_kpts)},跳过")
            skipped += 1
            continue

        # ============ Score keypoints 0 and 1 only ============
        errors = np.linalg.norm(pred_kpts[selected_idx] - gt_kpts[selected_idx], axis=1)
        mean_error = np.mean(errors)
        total_errors.append(mean_error)
        print(f"📸 {img_filename}: 点0&1误差={np.round(errors, 2)} 像素, 平均误差={mean_error:.2f}px")

        # Visualization: ground truth and prediction on a copy of the image.
        vis_img = img.copy()
        vis_img = draw_keypoints_on_image(vis_img, gt_kpts, colors_gt, label_prefix='GT')
        vis_img = draw_keypoints_on_image(vis_img, pred_kpts, colors_pred, label_prefix='P')
        save_path = os.path.join(OUTPUT_DIR, f"compare_{img_filename}")
        # imwrite reports failure through its return value; surface it.
        if not cv2.imwrite(save_path, vis_img):
            print(f"⚠️ 无法保存可视化图像: {save_path}")

    # ====================== Summary ======================
    print("\n======================")
    if total_errors:
        print(f"🎯 有效样本数: {len(total_errors)} 张")
        print(f"🚫 跳过样本数: {skipped} 张")
        print(f"📈 平均关键点误差: {np.mean(total_errors):.2f} 像素")
    else:
        print(f"⚠️ 所有样本均被跳过(跳过 {skipped} 张)")
    print("======================")


if __name__ == "__main__":
    main()