"""Run a YOLO pose model over a folder of images and export YOLO-format labels.

For every readable image in ``IMAGE_SOURCE_DIR`` a ``<stem>.txt`` file is
written to ``OUTPUT_DIR``.  Each line describes one detection as
``cls xc yc w h x1 y1 v1 x2 y2 v2 ...`` with all coordinates normalized by the
image width/height.  An empty label file is produced when nothing is detected.
"""

import os

import cv2
import numpy as np
from ultralytics import YOLO

# ====================== User configuration ======================
MODEL_PATH = 'best.pt'
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251230'
OUTPUT_DIR = './keypoints_txt'
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}

os.makedirs(OUTPUT_DIR, exist_ok=True)


# ====================== Label writing ======================
def _format_yolo_pose_line(cls_id, box_xyxy, kpts_xy, W, H):
    """Build one label line: ``cls xc yc w h x1 y1 v1 x2 y2 v2 ...``."""
    x1, y1, x2, y2 = box_xyxy
    fields = [
        str(cls_id),
        f"{(x1 + x2) / 2.0 / W:.6f}",
        f"{(y1 + y2) / 2.0 / H:.6f}",
        f"{(x2 - x1) / W:.6f}",
        f"{(y2 - y1) / H:.6f}",
    ]
    for x, y in kpts_xy:
        if x == 0 and y == 0:
            # NOTE(review): some Ultralytics releases zero the xy of
            # low-confidence keypoints in `Keypoints.xy` -- confirm for the
            # installed version.  Writing such a point with visibility 2 would
            # label a "visible" keypoint at the image corner, so it is emitted
            # as "0 0 0" (the YOLO-pose convention for an unlabeled point).
            fields.extend(("0.000000", "0.000000", "0"))
        else:
            # visibility=2 means "labeled and visible"
            fields.extend((f"{x / W:.6f}", f"{y / H:.6f}", "2"))
    return " ".join(fields)


def save_yolo_kpts_with_bbox(save_path, cls_id, box_xyxy, kpts_xy, W, H):
    """Append one YOLO-format detection line to ``save_path``.

    Args:
        save_path: Label file to append to (created if missing).
        cls_id: Integer class index of the detection.
        box_xyxy: Bounding box as ``(x1, y1, x2, y2)`` in pixels.
        kpts_xy: Iterable of ``(x, y)`` keypoint pixel coordinates.
        W: Image width in pixels, used to normalize x values.
        H: Image height in pixels, used to normalize y values.
    """
    line = _format_yolo_pose_line(cls_id, box_xyxy, kpts_xy, W, H)
    with open(save_path, "a", encoding="utf-8") as f:
        f.write(line + "\n")
    print(f" 💾 写入: {save_path}")


# ====================== Main program ======================
def main():
    """Detect boxes + keypoints for every image and write the label files."""
    print("🚀 开始 YOLO 检测框 + 关键点 TXT 输出")
    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    # Sorted so that processing order (and the log) is reproducible;
    # os.listdir() returns entries in arbitrary order.
    image_files = sorted(
        name
        for name in os.listdir(IMAGE_SOURCE_DIR)
        if os.path.splitext(name)[1].lower() in IMG_EXTENSIONS
    )
    if not image_files:
        print("❌ 未找到图像文件")
        # The bare `exit()` builtin is injected by the `site` module and is
        # not guaranteed to exist; SystemExit always is.
        raise SystemExit(1)

    for img_filename in image_files:
        print(f"\n🖼️ 正在处理: {img_filename}")
        img = cv2.imread(os.path.join(IMAGE_SOURCE_DIR, img_filename))
        if img is None:
            print("❌ 无法读取图像")
            continue
        H, W = img.shape[:2]

        # Inference
        results = model(img)

        stem = os.path.splitext(img_filename)[0]
        save_path = os.path.join(OUTPUT_DIR, stem + ".txt")

        # Always create (or truncate) the label file first, so images with
        # no detections still get an empty .txt and reruns do not append to
        # stale content.
        with open(save_path, "w", encoding="utf-8"):
            pass
        print(f" 📄 初始化空标签文件: {save_path}")

        has_detection = False
        for result in results:
            if result.boxes is None or len(result.boxes) == 0:
                continue

            boxes = result.boxes.xyxy.cpu().numpy()
            classes = result.boxes.cls.cpu().numpy()

            # Keypoints may be absent (e.g. the model is not a pose model).
            keypoints = getattr(result, 'keypoints', None)
            if keypoints is not None:
                kpts_xy = keypoints.xy.cpu().numpy()  # presumably (N, K, 2)
            else:
                kpts_xy = []

            for idx, (box, cls_val) in enumerate(zip(boxes, classes)):
                kpts = kpts_xy[idx] if idx < len(kpts_xy) else np.array([])

                # Detections without any keypoints are skipped; the model is
                # assumed to output a fixed number of keypoints otherwise.
                if kpts.size == 0:
                    print(" ⚠ 跳过无关键点的目标")
                    continue

                save_yolo_kpts_with_bbox(save_path, int(cls_val), box, kpts, W, H)
                has_detection = True

        # Optional hint when nothing usable was detected.
        if not has_detection:
            print(" ℹ️ 未检测到有效目标,保留空文件")

    print("\n==============================================")
    print("🎉 全部图像处理完毕!")
    print(f"📁 YOLO 输出目录:{OUTPUT_DIR}")
    print("✅ 所有图像均已生成对应 .txt 文件(含空文件)")
    print("==============================================")


if __name__ == "__main__":
    main()