import os
import xml.etree.ElementTree as ET
from pathlib import Path


def _extract_class_map(root):
    """Build a ``{label name: class id}`` mapping from the CVAT ``<labels>`` element.

    IDs follow the position of each ``<label>`` inside ``<labels>``.
    Labels without a ``<name>`` child are skipped (their position is still
    counted, so the IDs of the remaining labels are unaffected).

    Raises:
        RuntimeError: If the document contains no ``<labels>`` element.
    """
    labels_elem = root.find(".//labels")
    if labels_elem is None:
        raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")

    mapping = {}
    for idx, label in enumerate(labels_elem.findall("label")):
        name = label.findtext("name")
        if name is None:
            continue
        mapping[name] = idx
    return mapping


def _normalize_points(points_str, width, height):
    """Convert a CVAT ``"x1,y1;x2,y2;..."`` string to normalized YOLO coordinates.

    Returns a flat list of strings ``[x1, y1, x2, y2, ...]``, each formatted
    with six decimals. Empty tokens (e.g. produced by a trailing ``;``) are
    ignored, and every value is clamped to ``[0, 1]`` because polygons drawn
    past the image border would otherwise yield out-of-range coordinates.
    """
    coords = []
    for token in points_str.strip().split(";"):
        token = token.strip()
        if not token:
            continue
        x_str, y_str = token.split(",")
        x = min(max(float(x_str) / width, 0.0), 1.0)
        y = min(max(float(y_str) / height, 0.0), 1.0)
        coords.append(f"{x:.6f}")
        coords.append(f"{y:.6f}")
    return coords


def cvat_to_yolo_seg(
    xml_path,
    output_dir,
    class_name_to_id=None,
    force_class_id=None
):
    """Convert a CVAT XML export (polygon / segmentation) to YOLO segmentation labels.

    One ``<stem>.txt`` file is written to ``output_dir`` for every ``<image>``
    element that is a direct child of the XML root, even when the image has no
    usable polygons (the file is then empty). Each output line has the form
    ``"<class_id> x1 y1 x2 y2 ..."`` with coordinates normalized by the image
    width/height.

    Args:
        xml_path (str): Path to the XML file exported from CVAT.
        output_dir (str): Directory that receives the ``.txt`` label files;
            created if it does not exist.
        class_name_to_id (dict, optional): Mapping from label name to class ID.
            When omitted, it is derived from the ``<labels>`` element of the
            XML. The dictionary passed in is never modified.
        force_class_id (dict, optional): Overrides for specific class IDs,
            e.g. ``{"yemian": 0}``. Names absent from the mapping are reported
            and skipped. No attempt is made to resolve ID collisions that an
            override may introduce.

    Raises:
        RuntimeError: If ``class_name_to_id`` is omitted and the XML has no
            ``<labels>`` element.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    root = ET.parse(xml_path).getroot()

    # ----------------------------
    # Resolve the class mapping
    # ----------------------------
    if class_name_to_id is None:
        class_name_to_id = _extract_class_map(root)
    else:
        # Work on a copy so the overrides below never leak into the
        # caller's dictionary.
        class_name_to_id = dict(class_name_to_id)

    print(f"原始类别映射: {class_name_to_id}")

    # ----------------------------
    # Apply forced class IDs
    # ----------------------------
    if force_class_id:
        for name, new_id in force_class_id.items():
            if name in class_name_to_id:
                old_id = class_name_to_id[name]
                class_name_to_id[name] = new_id
                print(f"强制修改类别映射: {name} {old_id} → {new_id}")
            else:
                print(f"类别 {name} 不存在,跳过")

    print(f"最终类别映射: {class_name_to_id}")

    # ----------------------------
    # Convert every image
    # ----------------------------
    for image in root.findall("image"):
        img_name = image.get("name")
        width = int(image.get("width"))
        height = int(image.get("height"))

        txt_path = output_dir / (Path(img_name).stem + ".txt")

        lines = []
        for polygon in image.findall("polygon"):
            label = polygon.get("label")
            if label not in class_name_to_id:
                continue
            class_id = class_name_to_id[label]

            # A missing "points" attribute is treated like an empty polygon.
            points = _normalize_points(polygon.get("points") or "", width, height)

            # Two values per vertex: fewer than 6 means fewer than 3 vertices,
            # which cannot describe a polygon.
            if len(points) < 6:
                continue

            lines.append(f"{class_id} " + " ".join(points))

        with open(txt_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    print("✅ CVAT segmentation → YOLO seg 转换完成")


if __name__ == "__main__":
    cvat_to_yolo_seg(
        xml_path="annotations.xml",
        output_dir="labels_seg",
        force_class_id={
            "yemian": 0
        }
    )