Files
zjsh_yolov11/yolo11_seg/cvattoseg.py

107 lines
3.1 KiB
Python
Raw Normal View History

2026-03-10 13:58:21 +08:00
import os
import xml.etree.ElementTree as ET
from pathlib import Path
def _polygon_points_to_yolo(points_str, width, height):
    """Convert a CVAT ``points`` attribute into normalized YOLO coordinates.

    Args:
        points_str (str | None): CVAT point list in pixel units, formatted
            as ``"x1,y1;x2,y2;..."``.
        width (int): Image width in pixels, used to normalize x values.
        height (int): Image height in pixels, used to normalize y values.

    Returns:
        list[str]: Flat ``[x1, y1, x2, y2, ...]`` list, each value clamped to
        ``[0, 1]`` and formatted with six decimals. Empty when ``points_str``
        is missing or blank.
    """
    if not points_str:
        return []

    coords = []
    for pair in points_str.strip().split(";"):
        pair = pair.strip()
        if not pair:
            # Tolerate a trailing or doubled ";" instead of failing on unpack.
            continue
        x_text, y_text = pair.split(",")
        # Clamp so that vertices drawn past the image border still yield
        # valid normalized coordinates.
        x = max(0.0, min(1.0, float(x_text) / width))
        y = max(0.0, min(1.0, float(y_text) / height))
        coords.append(f"{x:.6f}")
        coords.append(f"{y:.6f}")
    return coords


def cvat_to_yolo_seg(
    xml_path,
    output_dir,
    class_name_to_id=None,
    force_class_id=None
):
    """Convert a CVAT XML export (polygons) to YOLO segmentation label files.

    One ``<image stem>.txt`` file is written per ``<image>`` element found
    directly under the XML root. Each line of a file has the form
    ``class_id x1 y1 x2 y2 ...`` with coordinates normalized by the image
    width/height. Images without usable polygons get an empty file.

    Args:
        xml_path (str): Path to the XML file exported from CVAT.
        output_dir (str): Directory that receives the ``.txt`` label files;
            created if it does not exist.
        class_name_to_id (dict, optional): Mapping from label name to class
            ID. When omitted, it is built from the ``<labels>`` element of
            the XML, numbering labels in document order starting at 0. The
            dict passed in is never modified.
        force_class_id (dict, optional): Overrides applied on top of the
            mapping, e.g. ``{"yemian": 0}``. Names not present in the mapping
            are reported and skipped. No check is made for two names ending
            up with the same ID.

    Raises:
        RuntimeError: If ``class_name_to_id`` is omitted and the XML has no
            ``<labels>`` element.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    root = ET.parse(xml_path).getroot()

    # ----------------------------
    # Derive the class mapping from the XML when none is supplied
    # ----------------------------
    if class_name_to_id is None:
        labels_elem = root.find(".//labels")
        if labels_elem is None:
            raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")
        class_name_to_id = {}
        for idx, label in enumerate(labels_elem.findall("label")):
            name = label.findtext("name")
            if name:
                class_name_to_id[name] = idx
    else:
        # Work on a copy so the overrides below never leak into the
        # caller's dictionary.
        class_name_to_id = dict(class_name_to_id)
    print(f"原始类别映射: {class_name_to_id}")

    # ----------------------------
    # Apply forced class-ID overrides
    # ----------------------------
    if force_class_id:
        for name, new_id in force_class_id.items():
            if name in class_name_to_id:
                old_id = class_name_to_id[name]
                class_name_to_id[name] = new_id
                print(f"强制修改类别映射: {name} {old_id} → {new_id}")
            else:
                print(f"类别 {name} 不存在,跳过")
    print(f"最终类别映射: {class_name_to_id}")

    # ----------------------------
    # Convert every image
    # ----------------------------
    for image in root.findall("image"):
        img_name = image.get("name")
        width = int(image.get("width"))
        height = int(image.get("height"))
        txt_path = output_dir / (Path(img_name).stem + ".txt")

        lines = []
        for polygon in image.findall("polygon"):
            label = polygon.get("label")
            if label not in class_name_to_id:
                continue
            coords = _polygon_points_to_yolo(
                polygon.get("points"), width, height
            )
            # A polygon needs at least three vertices (six values).
            if len(coords) < 6:
                continue
            lines.append(f"{class_name_to_id[label]} " + " ".join(coords))

        # Written even when empty, so every image has a label file.
        txt_path.write_text("\n".join(lines), encoding="utf-8")

    print("✅ CVAT segmentation → YOLO seg 转换完成")
if __name__ == "__main__":
    # Script entry point: convert ./annotations.xml into ./labels_seg,
    # pinning the "yemian" label to class ID 0.
    forced_ids = {"yemian": 0}
    cvat_to_yolo_seg(
        "annotations.xml",
        "labels_seg",
        force_class_id=forced_ids,
    )