126 lines
4.7 KiB
Python
126 lines
4.7 KiB
Python
import cv2
|
||
import numpy as np
|
||
import os
|
||
from ultralytics import YOLO
|
||
|
||
# ====================== User configuration ======================

# Checkpoint to evaluate. A previously used training-run checkpoint was:
#   /home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai2/weights/best.pt
MODEL_PATH = 'pointn.pt'

# Directory scanned for input images.
# NOTE(review): originally described as the validation-set image directory,
# but the path points at the .../train split -- confirm which split is intended.
IMAGE_SOURCE_DIR = (
    '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train'
)

# Directory holding one label file per image (same stem, .txt suffix).
LABEL_DIR = (
    '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train'
)

# Where the GT-vs-prediction comparison images are written; created up front.
OUTPUT_DIR = './output_images'

# Lower-case file suffixes that are treated as images.
IMG_EXTENSIONS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tif',
    '.tiff',
    '.webp',
}

os.makedirs(OUTPUT_DIR, exist_ok=True)


# ====================== Visualization ======================

def draw_keypoints_on_image(image, kpts_xy, colors, label_prefix=''):
    """Draw numbered keypoints onto ``image`` in place and return it.

    Each point becomes a filled circle of radius 8. A text label
    ``f'{label_prefix}{n}'`` (``n`` is the 1-based point index) is placed
    10 px to the right of and 10 px above the point.

    Args:
        image: image array handed to OpenCV drawing calls; modified in place.
        kpts_xy: iterable of (x, y) pixel coordinates; each value is
            truncated with ``int()``.
        colors: sequence of OpenCV color tuples, cycled by point index.
        label_prefix: text prepended to each point number.

    Returns:
        The same ``image`` object that was passed in.
    """
    for idx, (x_raw, y_raw) in enumerate(kpts_xy):
        center = (int(x_raw), int(y_raw))
        color = colors[idx % len(colors)]
        cv2.circle(image, center, 8, color, -1)
        text_origin = (center[0] + 10, center[1] - 10)
        cv2.putText(image, f'{label_prefix}{idx + 1}', text_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    return image


# ====================== Label loading ======================

def load_keypoints_from_label(label_path, img_shape):
    """Read the first object's keypoints from a pose label file.

    Expected line format (whitespace separated, coordinates normalized)::

        <class> xc yc w h x1 y1 v1 x2 y2 v2 x3 y3 v3 x4 y4 v4

    That is at least 17 values: 1 class id + 4 bbox values + 4 keypoints
    x 3 (x, y, visibility). Only the first line of the file is read, and
    the visibility values are discarded.

    Args:
        label_path: path to the .txt label file.
        img_shape: image shape; only the first two entries (H, W) are used
            to scale normalized coordinates to pixels.

    Returns:
        A float ``(K, 2)`` NumPy array of pixel (x, y) coordinates
        (K >= 4). Returns ``None``, after printing a warning, if:

        - the file is missing;
        - the line has fewer than 17 values;
        - the keypoint values do not form (x, y, v) triplets;
        - any keypoint value is non-numeric.
    """
    if not os.path.exists(label_path):
        print(f"⚠️ 找不到标签文件: {label_path}")
        return None

    H, W = img_shape[:2]
    with open(label_path, 'r') as f:
        line = f.readline().strip().split()

    if len(line) < 17:
        print(f"⚠️ 标签长度不足: {label_path} ({len(line)}项)")
        return None

    kpt_values = line[5:]  # skip class id + 4 bbox values
    # Keypoints must come in (x, y, v) triplets. Without this check the
    # reshape below raises ValueError and aborts the whole evaluation run.
    if len(kpt_values) % 3 != 0:
        print(f"⚠️ 标签关键点项数不是 3 的倍数: {label_path} ({len(kpt_values)}项)")
        return None

    try:
        floats = [float(x) for x in kpt_values]
    except ValueError:
        # A corrupt label should skip this image, not crash the run.
        print(f"⚠️ 标签包含非数值项: {label_path}")
        return None

    coords = np.array(floats).reshape(-1, 3)[:, :2]  # (K, 2) normalized x, y
    coords[:, 0] *= W
    coords[:, 1] *= H
    return coords


# ====================== Main program ======================

if __name__ == "__main__":
    # Use sys.exit: the bare exit() builtin is only injected by the `site`
    # module and is not guaranteed to exist (e.g. under `python -S`).
    import sys

    print("🚀 开始验证集关键点误差计算")

    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # os.listdir order is arbitrary; sort so logs and outputs are reproducible.
    image_files = sorted(
        f for f in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
    )

    if not image_files:
        print("❌ 未找到图像文件")
        sys.exit(1)

    total_errors = []  # per-image mean error (px) over selected_idx
    skipped = 0
    colors_gt = [(0, 255, 0), (0, 200, 0), (0, 150, 0), (0, 100, 0)]
    colors_pred = [(0, 0, 255)] * 4
    # Only keypoints 0 and 1 enter the error metric; all points are still drawn.
    selected_idx = [0, 1]

    for img_filename in image_files:
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        label_path = os.path.join(LABEL_DIR, os.path.splitext(img_filename)[0] + '.txt')

        img = cv2.imread(img_path)
        if img is None:
            print(f"❌ 无法读取图像: {img_path}")
            skipped += 1
            continue

        gt_kpts = load_keypoints_from_label(label_path, img.shape)
        if gt_kpts is None or len(gt_kpts) < 4:
            print(f"⚠️ 跳过 {img_filename}:标签点不足")
            skipped += 1
            continue

        results = model(img, verbose=False)
        if not results or results[0].keypoints is None or len(results[0].keypoints) == 0:
            print(f"⚠️ {img_filename}: 无检测结果,跳过")
            skipped += 1
            continue

        # Uses the first detection only.
        # NOTE(review): presumably the highest-confidence one -- confirm.
        pred_kpts = results[0].keypoints.xy[0].cpu().numpy()
        if pred_kpts.shape[0] != gt_kpts.shape[0]:
            print(f"⚠️ {img_filename}: 点数不匹配 GT={len(gt_kpts)}, Pred={len(pred_kpts)},跳过")
            skipped += 1
            continue

        # Euclidean pixel distance per selected keypoint, then the image mean.
        errors = np.linalg.norm(pred_kpts[selected_idx] - gt_kpts[selected_idx], axis=1)
        mean_error = np.mean(errors)
        total_errors.append(mean_error)

        print(f"📸 {img_filename}: 点0&1误差={np.round(errors, 2)} 像素, 平均误差={mean_error:.2f}px")

        # Visualization: GT in greens, predictions in red, on a copy of the input.
        vis_img = img.copy()
        vis_img = draw_keypoints_on_image(vis_img, gt_kpts, colors_gt, label_prefix='GT')
        vis_img = draw_keypoints_on_image(vis_img, pred_kpts, colors_pred, label_prefix='P')

        save_path = os.path.join(OUTPUT_DIR, f"compare_{img_filename}")
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(save_path, vis_img):
            print(f"⚠️ 保存可视化图像失败: {save_path}")

    # ====================== Summary ======================
    print("\n======================")
    if total_errors:
        print(f"🎯 有效样本数: {len(total_errors)} 张")
        print(f"🚫 跳过样本数: {skipped} 张")
        print(f"📈 平均关键点误差: {np.mean(total_errors):.2f} 像素")
    else:
        print(f"⚠️ 所有样本均被跳过(跳过 {skipped} 张)")
    print("======================")