2025-10-21 14:11:52 +08:00
|
|
|
|
import xml.etree.ElementTree as ET
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
# =================== Configuration ===================
# Path to the CVAT XML annotation file.
xml_file = 'annotations.xml'

# Image folder (originally noted as "used to read width/height"; it is not
# referenced by the code visible in this file).
images_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226'

# Directory that receives the YOLO label files; created up front.
output_dir = 'labels_keypoints'
os.makedirs(output_dir, exist_ok=True)

# Class mapping: label name as it appears in the XML -> class index.
class_mapping = dict(clamp1=0, clamp0=1, kongliao=2, duiliao=3)

# If True, no .txt file is created for an image without a target;
# if False, an empty .txt file is created instead.
skip_empty_images = False
# =====================================================
|
|
|
|
|
|
|
|
|
|
|
|
def parse_points(points_str):
    """Parse a CVAT ``points`` string into a list of ``(x, y)`` float tuples.

    The string is a ``;``-separated list of ``x,y`` pairs, for example
    ``"10.5,20;30,40.25"``.

    Args:
        points_str: The raw ``points`` attribute value. ``None`` or an empty
            string yields an empty list.

    Returns:
        A list of ``(x, y)`` tuples of floats, in input order.

    Raises:
        IndexError: If a non-blank segment contains no ``,`` separator.
        ValueError: If a coordinate cannot be converted to ``float``.
    """
    if not points_str:
        return []

    points = []
    for segment in points_str.split(';'):
        # Skip blank segments so a trailing ';' or stray whitespace between
        # separators does not crash the conversion.
        if not segment.strip():
            continue
        # Split once per segment; only the first two fields are used.
        fields = segment.split(',')
        points.append((float(fields[0]), float(fields[1])))
    return points
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints):
    """Scale a bounding box and its keypoints by the image dimensions.

    Args:
        image_w: Image width used as the divisor for x values and widths.
        image_h: Image height used as the divisor for y values and heights.
        bbox: ``(cx, cy, w, h)`` in the same units as the image size.
        keypoints: Iterable of ``(x, y)`` pairs in the same units.

    Returns:
        A pair ``(box, flat)`` where ``box`` is the normalized
        ``(cx, cy, w, h)`` tuple and ``flat`` is a flat list
        ``[x, y, 2, x, y, 2, ...]``; the constant ``2`` is the visibility
        flag (v=2: visible) appended after every keypoint.
    """
    center_x, center_y, box_w, box_h = bbox
    norm_box = (
        center_x / image_w,
        center_y / image_h,
        box_w / image_w,
        box_h / image_h,
    )

    flat = []
    for px, py in keypoints:
        # x, y, then the visibility flag (2 = visible).
        flat.extend((px / image_w, py / image_h, 2))

    return norm_box, flat
|
|
|
|
|
|
|
|
|
|
|
|
def points_to_bbox(points):
    """Return the axis-aligned box enclosing ``points``.

    Args:
        points: Sequence of points; index 0 is x and index 1 is y.

    Returns:
        ``(x_center, y_center, width, height)`` of the smallest axis-aligned
        rectangle containing every point.

    Raises:
        ValueError: If ``points`` is empty (raised by ``min``).
    """
    x_coords, y_coords = [], []
    for pt in points:
        x_coords.append(pt[0])
        y_coords.append(pt[1])

    left, right = min(x_coords), max(x_coords)
    top, bottom = min(y_coords), max(y_coords)

    center = ((left + right) / 2, (top + bottom) / 2)
    size = (right - left, bottom - top)
    return center + size
|
|
|
|
|
|
|
|
|
|
|
|
def _write_label(image_name, lines=()):
    """Write ``lines`` to the label file for ``image_name`` under ``output_dir``.

    The label file name is the image name with its extension replaced by
    ``.txt``. Parent directories are created first, because the image name
    taken from the XML may contain sub-directories. With no lines, an empty
    file is produced.
    """
    label_path = os.path.join(output_dir, os.path.splitext(image_name)[0] + '.txt')
    parent = os.path.dirname(label_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(label_path, 'w') as f_out:
        f_out.writelines(f"{line}\n" for line in lines)


# Parse the XML file.
tree = ET.parse(xml_file)
root = tree.getroot()

for image_elem in root.findall('image'):
    image_name = image_elem.get('name')
    image_w = int(image_elem.get('width'))
    image_h = int(image_elem.get('height'))

    # Look up the keypoints annotation; only the first 'clamp1' points
    # element of each image is used.
    points_elem = image_elem.find("points[@label='clamp1']")
    if points_elem is None:
        print(f"⚠️ 图像 {image_name} 缺少 clamp1 关键点")
        if not skip_empty_images:
            _write_label(image_name)
        continue

    keypoints = parse_points(points_elem.get('points'))

    # Exactly four keypoints are required; anything else is treated as an
    # image without a valid target.
    if len(keypoints) != 4:
        print(f"⚠️ 图像 {image_name} 关键点数量错误({len(keypoints)})")
        if not skip_empty_images:
            _write_label(image_name)
        continue

    # Build the enclosing box from the keypoints, then normalize both.
    bbox = points_to_bbox(keypoints)
    (cx_n, cy_n, w_n, h_n), kpts_n = normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints)

    # One label line: class index, normalized box, then the flat keypoint list.
    fields = [str(class_mapping['clamp1']), str(cx_n), str(cy_n), str(w_n), str(h_n)]
    fields += [str(k) for k in kpts_n]
    _write_label(image_name, [' '.join(fields)])

print("🎉 关键点转换完成!(仅生成有效标注 txt 或者根据设置处理空标签)")
print(f"📂 输出目录: {output_dir}")
|