114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
import cv2
|
||
import numpy as np
|
||
from ultralytics import YOLO
|
||
import os
|
||
|
||
# ====================== 用户配置 ======================
|
||
MODEL_PATH = 'best.pt'
|
||
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251230'
|
||
OUTPUT_DIR = './keypoints_txt'
|
||
|
||
IMG_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
|
||
|
||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
# ====================== 保存函数 ======================
|
||
def save_yolo_kpts_with_bbox(save_path, cls_id, box_xyxy, kpts_xy, W, H):
|
||
"""
|
||
追加写入一行 YOLO 格式:cls xc yc w h x1 y1 2 x2 y2 2 ...
|
||
"""
|
||
x1, y1, x2, y2 = box_xyxy
|
||
xc = (x1 + x2) / 2.0 / W
|
||
yc = (y1 + y2) / 2.0 / H
|
||
w = (x2 - x1) / W
|
||
h = (y2 - y1) / H
|
||
|
||
line_items = [str(cls_id), f"{xc:.6f}", f"{yc:.6f}", f"{w:.6f}", f"{h:.6f}"]
|
||
|
||
for (x, y) in kpts_xy:
|
||
x_norm = x / W
|
||
y_norm = y / H
|
||
line_items.append(f"{x_norm:.6f}")
|
||
line_items.append(f"{y_norm:.6f}")
|
||
line_items.append("2") # visibility=2 表示可见
|
||
|
||
with open(save_path, "a") as f:
|
||
f.write(" ".join(line_items) + "\n")
|
||
|
||
print(f" 💾 写入: {save_path}")
|
||
|
||
|
||
# ====================== 主程序 ======================
|
||
if __name__ == "__main__":
|
||
print("🚀 开始 YOLO 检测框 + 关键点 TXT 输出")
|
||
|
||
model = YOLO(MODEL_PATH)
|
||
print(f"✅ 模型加载完成: {MODEL_PATH}")
|
||
|
||
image_files = [
|
||
f for f in os.listdir(IMAGE_SOURCE_DIR)
|
||
if os.path.splitext(f.lower())[1] in IMG_EXTENSIONS
|
||
]
|
||
|
||
if not image_files:
|
||
print("❌ 未找到图像文件")
|
||
exit(1)
|
||
|
||
for img_filename in image_files:
|
||
print(f"\n🖼️ 正在处理: {img_filename}")
|
||
|
||
img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
|
||
img = cv2.imread(img_path)
|
||
if img is None:
|
||
print("❌ 无法读取图像")
|
||
continue
|
||
|
||
H, W = img.shape[:2]
|
||
|
||
# 推理
|
||
results = model(img)
|
||
|
||
txt_name = os.path.splitext(img_filename)[0] + ".txt"
|
||
save_path = os.path.join(OUTPUT_DIR, txt_name)
|
||
|
||
# ✅ 关键修改:无论有没有检测结果,先创建(或清空)空文件
|
||
open(save_path, 'w').close() # 创建空文件,若存在则清空
|
||
print(f" 📄 初始化空标签文件: {save_path}")
|
||
|
||
has_detection = False
|
||
for result in results:
|
||
if result.boxes is None or len(result.boxes) == 0:
|
||
continue
|
||
|
||
boxes = result.boxes.xyxy.cpu().numpy()
|
||
classes = result.boxes.cls.cpu().numpy()
|
||
|
||
# 处理关键点(可能为 None)
|
||
if hasattr(result, 'keypoints') and result.keypoints is not None:
|
||
kpts_xy = result.keypoints.xy.cpu().numpy() # (N, K, 2)
|
||
else:
|
||
kpts_xy = [np.array([])] * len(boxes) # 无关键点时留空(但你的模型应有)
|
||
|
||
for i in range(len(boxes)):
|
||
cls_id = int(classes[i])
|
||
box = boxes[i]
|
||
kpts = kpts_xy[i] if len(kpts_xy) > i else np.array([])
|
||
|
||
# 如果关键点数量不符合预期(比如不是4个),可选择跳过或填充
|
||
# 这里假设模型输出固定数量关键点
|
||
if kpts.size == 0:
|
||
print(f" ⚠ 跳过无关键点的目标")
|
||
continue
|
||
|
||
save_yolo_kpts_with_bbox(save_path, cls_id, box, kpts, W, H)
|
||
has_detection = True
|
||
|
||
# 可选:提示是否检测到内容
|
||
if not has_detection:
|
||
print(f" ℹ️ 未检测到有效目标,保留空文件")
|
||
|
||
print("\n==============================================")
|
||
print("🎉 全部图像处理完毕!")
|
||
print(f"📁 YOLO 输出目录:{OUTPUT_DIR}")
|
||
print("✅ 所有图像均已生成对应 .txt 文件(含空文件)")
|
||
print("==============================================") |