import cv2 import numpy as np from ultralytics import YOLO import os # ====================== 用户配置 ====================== MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_pose/weights/best.pt' IMAGE_PATH = '/output_masks/3.png' OUTPUT_DIR = '/home/hx/yolo/output_images' # 保存结果的文件夹 os.makedirs(OUTPUT_DIR, exist_ok=True) # ====================== 可视化函数 ====================== def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape): """ 在图像上绘制关键点 :param image: OpenCV 图像 :param kpts_xy: (N, K, 2) 坐标 :param kpts_conf: (N, K) 置信度 :param orig_shape: 原图尺寸 (H, W) """ colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青 for i in range(len(kpts_xy)): xy = kpts_xy[i] # (4, 2) conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量 for j in range(len(xy)): x, y = xy[j] c = conf[j] if hasattr(conf, '__len__') else conf x, y = int(x), int(y) # 检查坐标是否在图像范围内 if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]: continue # 只绘制置信度 > 0.5 的点 if c < 0.5: continue # 绘制圆圈 cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1) # 标注编号 cv2.putText(image, f'{j+1}', (x + 20, y - 20), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5) return image # ====================== 主程序 ====================== if __name__ == "__main__": print("正在加载模型...") model = YOLO(MODEL_PATH) print(f"✅ 模型加载完成: {MODEL_PATH}") img = cv2.imread(IMAGE_PATH) if img is None: raise FileNotFoundError(f"无法加载图像: {IMAGE_PATH}") print(f"✅ 图像加载成功: {IMAGE_PATH} (shape: {img.shape})") print("\n开始推理...") results = model(img) for i, result in enumerate(results): if result.keypoints is not None: kpts = result.keypoints orig_shape = kpts.orig_shape # (H, W) # ✅ 正确获取坐标和置信度 kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2) kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2]) print(f"\n 结果 {i + 1}:") print(f" 检测到 {len(kpts_xy)} 个实例") print(f" 坐标形状: {kpts_xy.shape}") print(f" 置信度形状: {kpts_conf.shape}") # 绘图 img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape) # 保存 base_name = os.path.basename(IMAGE_PATH) save_path = os.path.join(OUTPUT_DIR, f"keypoints_{base_name}") cv2.imwrite(save_path, img_with_kpts) print(f"✅ 结果已保存: {save_path}") # 显示(可选) display_img = cv2.resize(img_with_kpts, (1280, 720)) # 缩小显示 cv2.imshow("Keypoints", display_img) print("按任意键关闭...") cv2.waitKey(0) cv2.destroyAllWindows() else: print(" ❌ 未检测到关键点")