Files
zjsh_yolov11/yolo11_point/trans_cvattopoint.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

83 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# convert_cvat_to_yolo_keypoints.py
import xml.etree.ElementTree as ET
import os
# =================== Configuration ===================
xml_file = 'annotations.xml'  # path to the CVAT XML export
# Image folder. Not referenced by the conversion code in this script;
# image width/height are read from the XML <image> attributes instead.
images_dir = 'images'
output_dir = 'labels_keypoints'  # destination directory for YOLO label files
os.makedirs(output_dir, exist_ok=True)

# Class name -> YOLO class id (keys must match the label names in the XML)
class_mapping = dict(clamp1=0, clamp0=1, kongliao=2, duiliao=3)
# =====================================================
def parse_points(points_str):
    """Parse a CVAT ``points`` attribute string into a list of (x, y) tuples.

    The expected form is ``"x1,y1;x2,y2;..."``. Surrounding whitespace and
    empty segments (for example from a trailing ``';'``) are ignored.

    Args:
        points_str: Raw value of the ``points`` attribute.

    Returns:
        List of ``(x, y)`` float tuples, in the order they appear.

    Raises:
        ValueError: If a segment is not exactly two comma-separated numbers.
    """
    points = []
    for segment in points_str.split(';'):
        segment = segment.strip()
        if not segment:
            continue  # tolerate "a,b;" and stray separators
        coords = segment.split(',')
        if len(coords) != 2:
            raise ValueError(f"Malformed point {segment!r} in {points_str!r}")
        x_str, y_str = coords
        points.append((float(x_str), float(y_str)))
    return points
def normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints, visibility=2):
    """Normalize a pixel-space bbox and keypoints by the image size.

    Args:
        image_w: Image width in pixels; must be positive.
        image_h: Image height in pixels; must be positive.
        bbox: ``(cx, cy, w, h)`` in pixels.
        keypoints: Iterable of ``(x, y)`` pixel coordinates.
        visibility: Flag appended after each keypoint's coordinates.
            Defaults to 2 (v=2: visible).

    Returns:
        ``((cx_n, cy_n, w_n, h_n), kpts_n)``, where ``kpts_n`` is the flat list
        ``[x1, y1, v1, x2, y2, v2, ...]``.

    Raises:
        ValueError: If ``image_w`` or ``image_h`` is not positive.
    """
    if image_w <= 0 or image_h <= 0:
        raise ValueError(f"Image size must be positive, got {image_w}x{image_h}")
    cx, cy, w, h = bbox
    bbox_n = (cx / image_w, cy / image_h, w / image_w, h / image_h)
    kpts_n = []
    for x, y in keypoints:
        kpts_n.extend((x / image_w, y / image_h, visibility))
    return bbox_n, kpts_n
def points_to_bbox(points):
    """Return the tightest axis-aligned box around *points*.

    Args:
        points: Non-empty sequence of indexable ``(x, y)`` pairs.

    Returns:
        ``(x_center, y_center, width, height)`` in the same units as the input.
    """
    left = min(p[0] for p in points)
    right = max(p[0] for p in points)
    top = min(p[1] for p in points)
    bottom = max(p[1] for p in points)
    return (left + right) / 2, (top + bottom) / 2, right - left, bottom - top
def _build_label_lines(image_elem, image_name, image_w, image_h):
    """Return one YOLO keypoint label line per valid clamp1 instance in *image_elem*.

    Instances whose point count is not 4 are reported and skipped. An empty
    list means the image has no usable clamp1 annotation.
    """
    lines = []
    points_elems = image_elem.findall("points[@label='clamp1']")
    if not points_elems:
        print(f"⚠️ 警告:图像 {image_name} 缺少 clamp1 关键点标注")
        return lines
    for points_elem in points_elems:
        keypoints = parse_points(points_elem.get('points'))
        if len(keypoints) != 4:
            print(f"⚠️ 警告:图像 {image_name} 的关键点数量不是 4实际为 {len(keypoints)}")
            continue
        # The bounding box is the tightest rectangle around the 4 keypoints.
        bbox = points_to_bbox(keypoints)
        (cx_n, cy_n, w_n, h_n), kpts_n = normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints)
        fields = [class_mapping['clamp1'], cx_n, cy_n, w_n, h_n, *kpts_n]
        lines.append(' '.join(str(v) for v in fields))
    return lines


def main():
    """Convert every <image> in the CVAT XML into a YOLO keypoint label file.

    Label files are written only for images with at least one valid clamp1
    instance. An image that fails validation therefore gets no label file,
    rather than an empty file that YOLO would treat as a background image.
    """
    root = ET.parse(xml_file).getroot()
    written = 0
    skipped = 0
    for image_elem in root.findall('image'):
        image_name = image_elem.get('name')
        image_w = int(image_elem.get('width'))
        image_h = int(image_elem.get('height'))

        lines = _build_label_lines(image_elem, image_name, image_w, image_h)
        if not lines:
            skipped += 1
            continue

        label_file = os.path.splitext(image_name)[0] + '.txt'
        label_path = os.path.join(output_dir, label_file)
        # CVAT image names may contain sub-directories; mirror them.
        os.makedirs(os.path.dirname(label_path) or '.', exist_ok=True)
        with open(label_path, 'w', encoding='utf-8') as f_out:
            f_out.write('\n'.join(lines) + '\n')
        written += 1

    if skipped:
        print(f"⚠️ 完成:已生成 {written} 个标签文件,跳过 {skipped} 张图像")
    else:
        print("✅ 所有标注已成功转换为 YOLO 关键点格式!")
    print(f"📌 输出目录: {output_dir}")


if __name__ == "__main__":
    main()