import cv2 import numpy as np from ultralytics import YOLO import os # ====================== 用户配置 ====================== MODEL_PATH = '/home/hx/yolo/ultralytics_yolo11-main/runs/train/exp_ailai/weights/best.pt' IMAGE_SOURCE_DIR = '/home/hx/开发/ailai_image_obb/ailai_pc/test' # 👈 修改为你的图像文件夹路径 OUTPUT_DIR = './output_images' # 保存结果的文件夹 # 支持的图像扩展名 IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'} os.makedirs(OUTPUT_DIR, exist_ok=True) # ====================== 可视化函数 ====================== def draw_keypoints_on_image(image, kpts_xy, kpts_conf, orig_shape): """ 在图像上绘制关键点 :param image: OpenCV 图像 :param kpts_xy: (N, K, 2) 坐标 :param kpts_conf: (N, K) 置信度 :param orig_shape: 原图尺寸 (H, W) """ colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0)] # 1红, 2蓝, 3绿, 4青 for i in range(len(kpts_xy)): xy = kpts_xy[i] # (4, 2) conf = kpts_conf[i] if kpts_conf.ndim == 2 else kpts_conf[i:i+1] # (4,) 或标量 for j in range(len(xy)): x, y = xy[j] c = conf[j] if hasattr(conf, '__len__') else conf x, y = int(x), int(y) # 检查坐标是否在图像范围内 if x < 0 or y < 0 or x >= orig_shape[1] or y >= orig_shape[0]: continue # 只绘制置信度 > 0.5 的点 if c < 0.5: continue # 绘制实心圆 cv2.circle(image, (x, y), radius=15, color=colors[j], thickness=-1) # 标注编号(偏移避免遮挡) cv2.putText(image, f'{j+1}', (x + 20, y - 20), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colors[j], 5) return image # ====================== 主程序 ====================== if __name__ == "__main__": print("🚀 开始批量关键点检测任务") # 加载模型 print("🔄 加载 YOLO 模型...") model = YOLO(MODEL_PATH) print(f"✅ 模型加载完成: {MODEL_PATH}") # 获取所有图像文件 image_files = [ f for f in os.listdir(IMAGE_SOURCE_DIR) if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS ] if not image_files: print(f"❌ 错误:在 {IMAGE_SOURCE_DIR} 中未找到支持的图像文件") exit(1) print(f"📁 发现 {len(image_files)} 张图像待处理") # 遍历每张图像 for img_filename in image_files: img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename) print(f"\n🖼️ 正在处理: {img_filename}") # 读取图像 img = cv2.imread(img_path) if img is None: print(f"❌ 无法读取图像,跳过: {img_path}") continue print(f" ✅ 图像加载成功 (shape: {img.shape})") # 推理 print(" 🔍 正在推理...") results = model(img) processed = False # 标记是否处理了关键点 for i, result in enumerate(results): if result.keypoints is not None: kpts = result.keypoints orig_shape = kpts.orig_shape # (H, W) # 获取坐标和置信度 kpts_xy = kpts.xy.cpu().numpy() # (N, K, 2) kpts_conf = kpts.conf.cpu().numpy() if kpts.conf is not None else np.ones(kpts_xy.shape[:2]) print(f" ✅ 检测到 {len(kpts_xy)} 个实例") # 绘制关键点 img_with_kpts = draw_keypoints_on_image(img.copy(), kpts_xy, kpts_conf, orig_shape) # 保存图像 save_filename = f"keypoints_{img_filename}" save_path = os.path.join(OUTPUT_DIR, save_filename) cv2.imwrite(save_path, img_with_kpts) print(f" 💾 结果已保存: {save_path}") # 可选:显示图像(每次一张,按任意键继续) # display_img = cv2.resize(img_with_kpts, (1280, 720)) # cv2.imshow("Keypoints Detection", display_img) # print(" ⌨️ 按任意键继续...") # cv2.waitKey(0) # cv2.destroyAllWindows() processed = True if not processed: print(f" ❌ 未检测到关键点,跳过保存") print("\n" + "=" * 60) print("🎉 批量推理完成!") print(f"📊 总共处理 {len(image_files)} 张图像") print(f"📁 结果保存在: {OUTPUT_DIR}") print("=" * 60)