Files
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

114 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys

import cv2
import numpy as np
from ultralytics import YOLO
# ====================== User configuration ======================
# Weights file handed to YOLO() when the script runs.
MODEL_PATH = 'best.pt'
# Directory holding the input images.
IMAGE_SOURCE_DIR = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251230'
# Destination for the per-image label .txt files; created at import time below.
OUTPUT_DIR = './keypoints_txt'
# Accepted image extensions, lower-case with the leading dot (filenames are
# lower-cased before being compared against this set).
IMG_EXTENSIONS = set('.png .jpg .jpeg .bmp .tiff .tif .webp'.split())

os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== Save helpers ======================
def _clamp(value, low, high):
    """Return float(value) limited to the closed interval [low, high]."""
    return min(max(float(value), low), high)


def save_yolo_kpts_with_bbox(save_path, cls_id, box_xyxy, kpts_xy, W, H):
    """Append one YOLO-pose label line to *save_path*.

    The line has the form ``cls xc yc w h x1 y1 v1 x2 y2 v2 ...``: the box
    centre/size and every keypoint are normalized by the image size and
    clamped to [0, 1]; each keypoint is followed by a visibility flag.

    Args:
        save_path: Label file to append to (created if it does not exist).
        cls_id: Class id written as the first field.
        box_xyxy: Bounding box corners (x1, y1, x2, y2) in pixels.
        kpts_xy: Iterable of (x, y) keypoint pairs in pixels.
        W: Image width in pixels.
        H: Image height in pixels.
    """
    # Clip the box corners to the image in pixel space first, so the derived
    # centre/size stay consistent; Ultralytics' dataset checker rejects label
    # files containing coordinates outside [0, 1].
    x1, y1, x2, y2 = box_xyxy
    x1, x2 = _clamp(x1, 0.0, W), _clamp(x2, 0.0, W)
    y1, y2 = _clamp(y1, 0.0, H), _clamp(y2, 0.0, H)
    xc = (x1 + x2) / 2.0 / W
    yc = (y1 + y2) / 2.0 / H
    w = (x2 - x1) / W
    h = (y2 - y1) / H

    line_items = [str(cls_id), f"{xc:.6f}", f"{yc:.6f}", f"{w:.6f}", f"{h:.6f}"]
    for x, y in kpts_xy:
        if x == 0 and y == 0:
            # Ultralytics' Keypoints wrapper zeroes the xy of keypoints whose
            # confidence is below 0.5 (confirm for the installed version); the
            # YOLO-pose convention for such points is "0 0 0" (not labelled).
            line_items += ["0.000000", "0.000000", "0"]
            continue
        line_items.append(f"{_clamp(x / W, 0.0, 1.0):.6f}")
        line_items.append(f"{_clamp(y / H, 0.0, 1.0):.6f}")
        line_items.append("2")  # visibility=2 means "labelled and visible"

    with open(save_path, "a", encoding="utf-8") as f:
        f.write(" ".join(line_items) + "\n")
    print(f" 💾 写入: {save_path}")
# ====================== Main program ======================
def _list_images(src_dir):
    """Return the regular files in *src_dir* with an image extension, sorted by name.

    Sorting makes the processing order (and therefore which file wins a
    label-name collision) deterministic; os.listdir order is arbitrary.
    """
    return sorted(
        name
        for name in os.listdir(src_dir)
        if os.path.splitext(name.lower())[1] in IMG_EXTENSIONS
        and os.path.isfile(os.path.join(src_dir, name))
    )


def _write_labels(results, save_path, W, H):
    """Append a label line for every detection in *results* that has keypoints.

    Args:
        results: Iterable of Ultralytics result objects for one image.
        save_path: Label file that receives the lines.
        W: Image width in pixels.
        H: Image height in pixels.

    Returns:
        True if at least one line was written, otherwise False.
    """
    has_detection = False
    for result in results:
        if result.boxes is None or len(result.boxes) == 0:
            continue
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()
        # Keypoints may be None (e.g. a model without a pose head).
        if getattr(result, 'keypoints', None) is not None:
            kpts_xy = result.keypoints.xy.cpu().numpy()  # presumably (N, K, 2)
        else:
            kpts_xy = [np.array([])] * len(boxes)
        for i, (box, cls) in enumerate(zip(boxes, classes)):
            kpts = kpts_xy[i] if i < len(kpts_xy) else np.array([])
            # The keypoint count per object is assumed to be fixed by the model.
            if kpts.size == 0:
                print(" ⚠ 跳过无关键点的目标")
                continue
            save_yolo_kpts_with_bbox(save_path, int(cls), box, kpts, W, H)
            has_detection = True
    return has_detection


def main():
    """Run the model over IMAGE_SOURCE_DIR and write one label .txt per readable image."""
    print("🚀 开始 YOLO 检测框 + 关键点 TXT 输出")

    # Validate the inputs before paying for model loading.
    if not os.path.isdir(IMAGE_SOURCE_DIR):
        print(f"❌ 图像目录不存在: {IMAGE_SOURCE_DIR}")
        sys.exit(1)
    image_files = _list_images(IMAGE_SOURCE_DIR)
    if not image_files:
        print("❌ 未找到图像文件")
        sys.exit(1)

    model = YOLO(MODEL_PATH)
    print(f"✅ 模型加载完成: {MODEL_PATH}")

    unreadable = []
    stem_owner = {}  # label-file stem -> image that produced it
    for img_filename in image_files:
        print(f"\n🖼️ 正在处理: {img_filename}")
        img_path = os.path.join(IMAGE_SOURCE_DIR, img_filename)
        img = cv2.imread(img_path)
        if img is None:
            print("❌ 无法读取图像")
            unreadable.append(img_filename)
            continue
        H, W = img.shape[:2]
        results = model(img)

        stem = os.path.splitext(img_filename)[0]
        # a.jpg and a.png map to the same a.txt; the later image truncates the
        # earlier one's labels, so make that visible instead of silent.
        if stem in stem_owner:
            print(f" ⚠ {img_filename} 与 {stem_owner[stem]} 同名，标签文件将被覆盖")
        stem_owner[stem] = img_filename
        save_path = os.path.join(OUTPUT_DIR, stem + ".txt")

        # Create (or truncate) the label file up front so images without
        # detections still get an empty .txt.
        with open(save_path, 'w', encoding='utf-8'):
            pass
        print(f" 📄 初始化空标签文件: {save_path}")

        if not _write_labels(results, save_path, W, H):
            print(" 未检测到有效目标,保留空文件")

    print("\n==============================================")
    print("🎉 全部图像处理完毕!")
    print(f"📁 YOLO 输出目录:{OUTPUT_DIR}")
    if unreadable:
        # Unreadable images are skipped before their label file is created.
        print(f"⚠ 以下 {len(unreadable)} 张图像无法读取，未生成 .txt 文件: {', '.join(unreadable)}")
    else:
        print("✅ 所有图像均已生成对应 .txt 文件(含空文件)")
    print("==============================================")


if __name__ == "__main__":
    main()