Files

114 lines
3.9 KiB
Python
Raw Permalink Normal View History

2025-12-11 08:37:09 +08:00
import cv2
import numpy as np
from ultralytics import YOLO
import os
# ====================== User configuration ======================
# Weights file loaded with YOLO(...) by the main program.
MODEL_PATH = 'best.pt'
# Directory whose files (not subdirectories) are scanned for input images.
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251230'
# Destination for the per-image .txt label files.
OUTPUT_DIR = './keypoints_txt'
# Lower-case, dot-prefixed suffixes; file names are lower-cased before matching.
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
# Create the output directory when the module is loaded; no error if it exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== Label writer ======================
def save_yolo_kpts_with_bbox(save_path, cls_id, box_xyxy, kpts_xy, W, H):
    """Append one YOLO-pose label line for a single detection to *save_path*.

    The line has the form ``cls xc yc w h x1 y1 v1 x2 y2 v2 ...``. The box
    centre/size and every keypoint coordinate are normalised by the image
    width ``W`` / height ``H`` and written with 6 decimal places.

    Args:
        save_path: Label file to append to (created if it does not exist).
        cls_id: Class id written verbatim as the first field.
        box_xyxy: ``(x1, y1, x2, y2)`` box corners in pixels.
        kpts_xy: Iterable of keypoints; only the first two entries of each
            (``x, y`` in pixels) are used, so ``(K, 2)`` and ``(K, 3)``
            arrays are both accepted.
        W: Image width in pixels.
        H: Image height in pixels.

    Visibility flags: a keypoint at exactly ``(0, 0)`` is written with
    visibility ``0``; every other keypoint gets ``2`` (visible).
    Ultralytics' ``Keypoints`` result sets x/y to 0 for low-confidence
    keypoints. Flagging those as visible would create a bogus label at the
    image corner.
    """
    x1, y1, x2, y2 = box_xyxy
    fields = [
        str(cls_id),
        f"{(x1 + x2) / 2.0 / W:.6f}",
        f"{(y1 + y2) / 2.0 / H:.6f}",
        f"{(x2 - x1) / W:.6f}",
        f"{(y2 - y1) / H:.6f}",
    ]
    for pt in kpts_xy:
        x, y = pt[0], pt[1]
        # (0, 0) is the "not predicted / not visible" marker, not a real point.
        visibility = "0" if (x == 0 and y == 0) else "2"
        fields += [f"{x / W:.6f}", f"{y / H:.6f}", visibility]
    with open(save_path, "a", encoding="utf-8") as f:
        f.write(" ".join(fields) + "\n")
    print(f" 💾 写入: {save_path}")
# ====================== 主程序 ======================
if __name__ == "__main__":
print("🚀 开始 YOLO 检测框 + 关键点 TXT 输出")
model = YOLO(MODEL_PATH)
print(f"✅ 模型加载完成: {MODEL_PATH}")
image_files = [
f for f in os.listdir(IMAGE_SOURCE_DIR)
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
]
if not image_files:
print("❌ 未找到图像文件")
exit(1)
for img_filename in image_files:
print(f"\n🖼️ 正在处理: {img_filename}")
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
img = cv2.imread(img_path)
if img is None:
print("❌ 无法读取图像")
continue
H, W = img.shape[:2]
# 推理
results = model(img)
txt_name = os.path.splitext(img_filename)[0] + ".txt"
save_path = os.path.join(OUTPUT_DIR, txt_name)
# ✅ 关键修改:无论有没有检测结果,先创建(或清空)空文件
open(save_path, 'w').close() # 创建空文件,若存在则清空
print(f" 📄 初始化空标签文件: {save_path}")
has_detection = False
for result in results:
if result.boxes is None or len(result.boxes) == 0:
continue
boxes = result.boxes.xyxy.cpu().numpy()
classes = result.boxes.cls.cpu().numpy()
# 处理关键点(可能为 None
if hasattr(result, 'keypoints') and result.keypoints is not None:
kpts_xy = result.keypoints.xy.cpu().numpy() # (N, K, 2)
else:
kpts_xy = [np.array([])] * len(boxes) # 无关键点时留空(但你的模型应有)
for i in range(len(boxes)):
cls_id = int(classes[i])
box = boxes[i]
kpts = kpts_xy[i] if len(kpts_xy) > i else np.array([])
# 如果关键点数量不符合预期比如不是4个可选择跳过或填充
# 这里假设模型输出固定数量关键点
if kpts.size == 0:
print(f" ⚠ 跳过无关键点的目标")
continue
save_yolo_kpts_with_bbox(save_path, cls_id, box, kpts, W, H)
has_detection = True
# 可选:提示是否检测到内容
if not has_detection:
print(f" 未检测到有效目标,保留空文件")
print("\n==============================================")
print("🎉 全部图像处理完毕!")
print(f"📁 YOLO 输出目录:{OUTPUT_DIR}")
print("✅ 所有图像均已生成对应 .txt 文件(含空文件)")
print("==============================================")