2025-12-11 08:37:09 +08:00
|
|
|
|
import cv2
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 用户配置 ======================
|
2026-03-10 13:58:21 +08:00
|
|
|
|
MODEL_PATH = 'best.pt'
|
|
|
|
|
|
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251230'
|
2025-12-11 08:37:09 +08:00
|
|
|
|
OUTPUT_DIR = './keypoints_txt'
|
|
|
|
|
|
|
|
|
|
|
|
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 保存函数 ======================
|
|
|
|
|
|
def save_yolo_kpts_with_bbox(save_path, cls_id, box_xyxy, kpts_xy, W, H):
|
|
|
|
|
|
"""
|
|
|
|
|
|
追加写入一行 YOLO 格式:cls xc yc w h x1 y1 2 x2 y2 2 ...
|
|
|
|
|
|
"""
|
|
|
|
|
|
x1, y1, x2, y2 = box_xyxy
|
|
|
|
|
|
xc = (x1 + x2) / 2.0 / W
|
|
|
|
|
|
yc = (y1 + y2) / 2.0 / H
|
|
|
|
|
|
w = (x2 - x1) / W
|
|
|
|
|
|
h = (y2 - y1) / H
|
|
|
|
|
|
|
|
|
|
|
|
line_items = [str(cls_id), f"{xc:.6f}", f"{yc:.6f}", f"{w:.6f}", f"{h:.6f}"]
|
|
|
|
|
|
|
|
|
|
|
|
for (x, y) in kpts_xy:
|
|
|
|
|
|
x_norm = x / W
|
|
|
|
|
|
y_norm = y / H
|
|
|
|
|
|
line_items.append(f"{x_norm:.6f}")
|
|
|
|
|
|
line_items.append(f"{y_norm:.6f}")
|
|
|
|
|
|
line_items.append("2") # visibility=2 表示可见
|
|
|
|
|
|
|
|
|
|
|
|
with open(save_path, "a") as f:
|
|
|
|
|
|
f.write(" ".join(line_items) + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
print(f" 💾 写入: {save_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== 主程序 ======================
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
print("🚀 开始 YOLO 检测框 + 关键点 TXT 输出")
|
|
|
|
|
|
|
|
|
|
|
|
model = YOLO(MODEL_PATH)
|
|
|
|
|
|
print(f"✅ 模型加载完成: {MODEL_PATH}")
|
|
|
|
|
|
|
|
|
|
|
|
image_files = [
|
|
|
|
|
|
f for f in os.listdir(IMAGE_SOURCE_DIR)
|
|
|
|
|
|
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
if not image_files:
|
|
|
|
|
|
print("❌ 未找到图像文件")
|
|
|
|
|
|
exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
for img_filename in image_files:
|
|
|
|
|
|
print(f"\n🖼️ 正在处理: {img_filename}")
|
|
|
|
|
|
|
|
|
|
|
|
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
|
|
|
|
|
|
img = cv2.imread(img_path)
|
|
|
|
|
|
if img is None:
|
|
|
|
|
|
print("❌ 无法读取图像")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
H, W = img.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
# 推理
|
|
|
|
|
|
results = model(img)
|
|
|
|
|
|
|
|
|
|
|
|
txt_name = os.path.splitext(img_filename)[0] + ".txt"
|
|
|
|
|
|
save_path = os.path.join(OUTPUT_DIR, txt_name)
|
|
|
|
|
|
|
|
|
|
|
|
# ✅ 关键修改:无论有没有检测结果,先创建(或清空)空文件
|
|
|
|
|
|
open(save_path, 'w').close() # 创建空文件,若存在则清空
|
|
|
|
|
|
print(f" 📄 初始化空标签文件: {save_path}")
|
|
|
|
|
|
|
|
|
|
|
|
has_detection = False
|
|
|
|
|
|
for result in results:
|
|
|
|
|
|
if result.boxes is None or len(result.boxes) == 0:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
boxes = result.boxes.xyxy.cpu().numpy()
|
|
|
|
|
|
classes = result.boxes.cls.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# 处理关键点(可能为 None)
|
|
|
|
|
|
if hasattr(result, 'keypoints') and result.keypoints is not None:
|
|
|
|
|
|
kpts_xy = result.keypoints.xy.cpu().numpy() # (N, K, 2)
|
|
|
|
|
|
else:
|
|
|
|
|
|
kpts_xy = [np.array([])] * len(boxes) # 无关键点时留空(但你的模型应有)
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(boxes)):
|
|
|
|
|
|
cls_id = int(classes[i])
|
|
|
|
|
|
box = boxes[i]
|
|
|
|
|
|
kpts = kpts_xy[i] if len(kpts_xy) > i else np.array([])
|
|
|
|
|
|
|
|
|
|
|
|
# 如果关键点数量不符合预期(比如不是4个),可选择跳过或填充
|
|
|
|
|
|
# 这里假设模型输出固定数量关键点
|
|
|
|
|
|
if kpts.size == 0:
|
|
|
|
|
|
print(f" ⚠ 跳过无关键点的目标")
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
save_yolo_kpts_with_bbox(save_path, cls_id, box, kpts, W, H)
|
|
|
|
|
|
has_detection = True
|
|
|
|
|
|
|
|
|
|
|
|
# 可选:提示是否检测到内容
|
|
|
|
|
|
if not has_detection:
|
|
|
|
|
|
print(f" ℹ️ 未检测到有效目标,保留空文件")
|
|
|
|
|
|
|
|
|
|
|
|
print("\n==============================================")
|
|
|
|
|
|
print("🎉 全部图像处理完毕!")
|
|
|
|
|
|
print(f"📁 YOLO 输出目录:{OUTPUT_DIR}")
|
|
|
|
|
|
print("✅ 所有图像均已生成对应 .txt 文件(含空文件)")
|
|
|
|
|
|
print("==============================================")
|