Files
zjsh_yolov11/yolo11_point/trans_cvattopoint.py
琉璃月光 eb16eeada3 最新推送
2026-03-10 13:58:21 +08:00

93 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import xml.etree.ElementTree as ET
import os
# =================== Configuration ===================
# Path to the CVAT XML annotation file to convert.
xml_file = 'annotations.xml'
# Image folder (intended for reading image width/height).
images_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/20251226'
# Directory that receives the YOLO label files; created up front if missing.
output_dir = 'labels_keypoints'
os.makedirs(output_dir, exist_ok=True)

# Class mapping keyed by the label names used in the XML.
# A label's position in the tuple below is its class id.
class_mapping = {
    label: class_id
    for class_id, label in enumerate(('clamp1', 'clamp0', 'kongliao', 'duiliao'))
}

# If True, no .txt file is created for an image without a target;
# if False, an empty .txt file is created for it.
skip_empty_images = False
# =====================================================
def parse_points(points_str):
    """Parse a CVAT ``points`` attribute into a list of ``(x, y)`` float tuples.

    The attribute is a ``;``-separated list of ``x,y`` pairs, e.g.
    ``"10.5,20.0;30.0,40.25"``.

    Args:
        points_str: Raw attribute value. ``None`` or an empty string yields
            an empty list.

    Returns:
        A list of ``(x, y)`` tuples of floats, in input order.

    Raises:
        ValueError: If a non-empty segment does not contain two
            comma-separated numbers.
    """
    if not points_str:
        return []

    points = []
    for segment in points_str.split(';'):
        segment = segment.strip()
        if not segment:
            # Tolerate trailing or duplicated ';' separators.
            continue
        # Split once per segment; only the first two fields are coordinates.
        x_text, y_text = segment.split(',')[:2]
        points.append((float(x_text), float(y_text)))
    return points
def normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints):
    """Normalize a bounding box and its keypoints by the image size.

    Args:
        image_w: Image width used to divide every x value and the box width.
        image_h: Image height used to divide every y value and the box height.
        bbox: ``(cx, cy, w, h)`` in the same units as the image size.
        keypoints: Iterable of ``(x, y)`` pairs in the same units.

    Returns:
        A pair ``(bbox_n, kpts_n)`` where ``bbox_n`` is the normalized
        ``(cx, cy, w, h)`` tuple and ``kpts_n`` is a flat list
        ``[x, y, v, x, y, v, ...]`` with ``v`` always 2 (visible).
    """
    center_x, center_y, box_w, box_h = bbox
    bbox_n = (
        center_x / image_w,
        center_y / image_h,
        box_w / image_w,
        box_h / image_h,
    )

    kpts_n = []
    for px, py in keypoints:
        # Each keypoint contributes x, y and the visibility flag v=2 (visible).
        kpts_n.extend((px / image_w, py / image_h, 2))
    return bbox_n, kpts_n
def points_to_bbox(points):
    """Build the axis-aligned bounding rectangle of a set of points.

    Args:
        points: Sequence of points whose first two items are x and y.

    Returns:
        ``(x_center, y_center, width, height)`` of the smallest axis-aligned
        rectangle that contains every point.
    """
    left = min(pt[0] for pt in points)
    right = max(pt[0] for pt in points)
    top = min(pt[1] for pt in points)
    bottom = max(pt[1] for pt in points)
    return (left + right) / 2, (top + bottom) / 2, right - left, bottom - top
# Parse the CVAT XML export and write one YOLO keypoint label file per image.
def _label_path(image_name):
    """Return the label .txt path for *image_name*, creating parent folders.

    The image extension is replaced by ``.txt`` and the result is placed under
    ``output_dir``. If the image name carries a sub-directory prefix, the
    matching folder is created so that opening the label file does not fail.
    """
    path = os.path.join(output_dir, os.path.splitext(image_name)[0] + '.txt')
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    return path


def _write_empty_label(image_name):
    """Create an empty label file unless ``skip_empty_images`` is True."""
    if skip_empty_images:
        return
    with open(_label_path(image_name), 'w'):
        pass  # Opening in 'w' mode is enough to create/truncate the file.


tree = ET.parse(xml_file)
root = tree.getroot()
for image_elem in root.findall('image'):
    image_name = image_elem.get('name')
    image_w = int(image_elem.get('width'))
    image_h = int(image_elem.get('height'))

    # Look up the clamp1 keypoints; images without them get an empty label.
    # NOTE(review): find() returns only the first matching <points> element,
    # so additional clamp1 annotations on the same image are ignored —
    # confirm that each image carries at most one.
    points_elem = image_elem.find("points[@label='clamp1']")
    if points_elem is None:
        print(f"⚠️ 图像 {image_name} 缺少 clamp1 关键点")
        _write_empty_label(image_name)
        continue

    # A missing/empty points attribute is treated as zero keypoints so it is
    # reported by the count check below instead of crashing the whole run.
    points_attr = points_elem.get('points')
    keypoints = parse_points(points_attr) if points_attr else []
    if len(keypoints) != 4:
        print(f"⚠️ 图像 {image_name} 关键点数量错误({len(keypoints)}")
        _write_empty_label(image_name)
        continue

    # Derive the bounding box from the keypoints, then normalize both.
    bbox = points_to_bbox(keypoints)
    (cx_n, cy_n, w_n, h_n), kpts_n = normalize_bbox_and_kpts(image_w, image_h, bbox, keypoints)

    # One line per object: class id, normalized bbox, then x y v per keypoint.
    fields = [class_mapping['clamp1'], cx_n, cy_n, w_n, h_n, *kpts_n]
    with open(_label_path(image_name), 'w') as f_out:
        f_out.write(' '.join(str(field) for field in fields) + '\n')

print("🎉 关键点转换完成!(仅生成有效标注 txt 或者根据设置处理空标签)")
print(f"📂 输出目录: {output_dir}")