Files
zjsh_yolov11/yolo11_seg/cvattoseg.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

107 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import xml.etree.ElementTree as ET
from pathlib import Path
def _load_class_map(root):
    """Build ``{label_name: class_id}`` from the first ``<labels>`` element of a CVAT XML tree.

    Class IDs are assigned in the order the ``<label>`` elements appear (0, 1, 2, ...).

    Raises:
        RuntimeError: if the XML contains no ``<labels>`` element.
    """
    labels_elem = root.find(".//labels")
    if labels_elem is None:
        raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")
    return {
        label.find("name").text: idx
        for idx, label in enumerate(labels_elem.findall("label"))
    }


def _apply_forced_ids(class_name_to_id, force_class_id):
    """Override selected class IDs in ``class_name_to_id`` (in place), logging each change.

    Unknown class names are reported and skipped. After all overrides, any ID shared by
    more than one class name is reported: the merge may be intentional, but it is easy
    to create by accident when forcing an ID that another class already uses.
    """
    for name, new_id in force_class_id.items():
        if name not in class_name_to_id:
            print(f"类别 {name} 不存在,跳过")
            continue
        old_id = class_name_to_id[name]
        class_name_to_id[name] = new_id
        # Explicit separator: "{old_id}{new_id}" used to render e.g. 1 -> 0 as "10".
        print(f"强制修改类别映射: {name} {old_id} → {new_id}")

    names_by_id = {}
    for name, class_id in class_name_to_id.items():
        names_by_id.setdefault(class_id, []).append(name)
    for class_id, names in names_by_id.items():
        if len(names) > 1:
            print(f"⚠️ 多个类别共享 ID {class_id}: {names}")


def _polygon_to_yolo_line(class_id, points_str, width, height):
    """Convert one CVAT ``points`` string ("x1,y1;x2,y2;...") into a YOLO-seg label line.

    Coordinates are divided by the image width/height, clamped to [0, 1] and written
    with 6 decimals.

    Returns:
        ``"<class_id> x1 y1 x2 y2 ..."``, or ``None`` when the polygon has fewer than
        3 vertices (not a valid polygon).
    """
    coords = []
    for pair in (points_str or "").split(";"):
        pair = pair.strip()
        if not pair:
            # Tolerate empty segments such as a trailing ';'.
            continue
        x_str, y_str = pair.split(",")
        # CVAT may store vertices slightly outside the image; YOLO labels expect
        # normalized coordinates within [0, 1].
        x = min(1.0, max(0.0, float(x_str) / width))
        y = min(1.0, max(0.0, float(y_str) / height))
        coords.append(f"{x:.6f}")
        coords.append(f"{y:.6f}")
    if len(coords) < 6:  # fewer than 3 (x, y) vertices
        return None
    return f"{class_id} " + " ".join(coords)


def cvat_to_yolo_seg(
    xml_path,
    output_dir,
    class_name_to_id=None,
    force_class_id=None
):
    """Convert a CVAT XML export (polygon / segmentation) to YOLO segmentation labels.

    One ``<stem>.txt`` file is written per ``<image>`` element, holding one line per
    polygon: ``class_id x1 y1 x2 y2 ...`` with coordinates normalized by the image
    size. Images without any usable polygon still get an (empty) label file.
    Polygons whose label is not in the class mapping, or that have fewer than
    3 vertices, are skipped.

    Args:
        xml_path (str): Path to the XML file exported from CVAT.
        output_dir (str): Directory for the output ``.txt`` label files (created if needed).
        class_name_to_id (dict, optional): Mapping from class name to ID. If omitted,
            it is built from the XML ``<labels>`` section in declaration order. The
            caller's dict is not modified.
        force_class_id (dict, optional): Force specific classes to a given ID,
            e.g. ``{"yemian": 0}``.

    Raises:
        RuntimeError: if ``class_name_to_id`` is omitted and the XML has no ``<labels>``.
    """
    # Parse before touching the filesystem so a bad XML does not leave an empty folder.
    tree = ET.parse(xml_path)
    root = tree.getroot()

    # ----------------------------
    # Class mapping (auto-extracted or caller-supplied)
    # ----------------------------
    if class_name_to_id is None:
        class_name_to_id = _load_class_map(root)
    else:
        # Copy so that forced overrides never leak back into the caller's dict.
        class_name_to_id = dict(class_name_to_id)
    print(f"原始类别映射: {class_name_to_id}")

    # ----------------------------
    # Forced class-ID overrides
    # ----------------------------
    if force_class_id:
        _apply_forced_ids(class_name_to_id, force_class_id)
    print(f"最终类别映射: {class_name_to_id}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ----------------------------
    # One label file per <image>
    # ----------------------------
    written = set()
    for image in root.findall("image"):
        img_name = image.get("name")
        width = int(image.get("width"))
        height = int(image.get("height"))
        # Only the stem is kept, so images with the same stem in different
        # sub-folders map to the same label file; warn instead of overwriting silently.
        txt_path = output_dir / (Path(img_name).stem + ".txt")
        if txt_path in written:
            print(f"⚠️ 标注文件名冲突, 将覆盖: {txt_path}")
        written.add(txt_path)

        lines = []
        for polygon in image.findall("polygon"):
            label = polygon.get("label")
            if label not in class_name_to_id:
                continue
            line = _polygon_to_yolo_line(
                class_name_to_id[label], polygon.get("points"), width, height
            )
            if line is not None:
                lines.append(line)

        txt_path.write_text("\n".join(lines), encoding="utf-8")

    print("✅ CVAT segmentation → YOLO seg 转换完成")
if __name__ == "__main__":
    # Script entry point: convert ./annotations.xml into per-image YOLO-seg
    # label files under ./labels_seg, pinning the "yemian" class to ID 0.
    forced_ids = {"yemian": 0}
    cvat_to_yolo_seg(
        xml_path="annotations.xml",
        output_dir="labels_seg",
        force_class_id=forced_ids,
    )