107 lines
3.1 KiB
Python
107 lines
3.1 KiB
Python
|
||
import os
|
||
import xml.etree.ElementTree as ET
|
||
from pathlib import Path
|
||
|
||
|
||
def _build_class_map(root):
    """Return a ``{label_name: index}`` map read from the ``<labels>`` section of a CVAT XML tree.

    Indices follow the order in which ``<label>`` elements appear under the
    first ``<labels>`` element found anywhere in the tree.

    Raises:
        RuntimeError: If the tree contains no ``<labels>`` element.
    """
    labels_elem = root.find(".//labels")
    if labels_elem is None:
        raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")
    class_map = {}
    for idx, label in enumerate(labels_elem.findall("label")):
        class_map[label.find("name").text] = idx
    return class_map


def _polygon_to_yolo(points_str, width, height):
    """Convert a CVAT ``"x1,y1;x2,y2;..."`` points string into normalized YOLO coordinate tokens.

    Each coordinate is divided by the image width/height and clamped to
    [0, 1]: CVAT can store vertices slightly outside the image, while YOLO
    label files are expected to contain only normalized [0, 1] values.
    Empty segments (e.g. from a trailing ``;``) are ignored.

    Returns:
        list[str]: Flat ``[x1, y1, x2, y2, ...]`` tokens formatted with 6 decimals.
    """
    coords = []
    for pair in points_str.strip().split(";"):
        pair = pair.strip()
        if not pair:
            continue
        x_str, y_str = pair.split(",")
        # 0.0 goes first in max() so an exact -0.0 normalizes to 0.0, not "-0.000000".
        x = min(1.0, max(0.0, float(x_str) / width))
        y = min(1.0, max(0.0, float(y_str) / height))
        coords.append(f"{x:.6f}")
        coords.append(f"{y:.6f}")
    return coords


def cvat_to_yolo_seg(
    xml_path,
    output_dir,
    class_name_to_id=None,
    force_class_id=None
):
    """Convert a CVAT XML export (polygon / segmentation) into YOLO segmentation label files.

    One ``<image-stem>.txt`` file is written per top-level ``<image>`` element.
    The file is empty when the image has no usable polygon. Each line has the
    form ``class_id x1 y1 x2 y2 ...``, with coordinates normalized by the
    image width/height and clamped to [0, 1]. Polygons whose label is not in
    the mapping, or that have fewer than 3 vertices, are skipped.

    Args:
        xml_path (str): Path to the CVAT-exported XML file.
        output_dir (str): Directory for the output .txt label files.
            It is created if missing.
        class_name_to_id (dict, optional): Mapping from class name to class ID.
            When omitted, it is built from the XML ``<labels>`` section in
            order of appearance. The caller's dict is never modified.
        force_class_id (dict, optional): Overrides the ID of specific classes,
            e.g. ``{"yemian": 0}``. Names absent from the mapping are
            reported and skipped.

    Raises:
        RuntimeError: If ``class_name_to_id`` is omitted and the XML has no
            ``<labels>`` element.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    root = ET.parse(xml_path).getroot()

    # Build the mapping from the XML, or copy the caller's so the overrides
    # below never leak back into their dict.
    if class_name_to_id is None:
        class_name_to_id = _build_class_map(root)
    else:
        class_name_to_id = dict(class_name_to_id)

    print(f"原始类别映射: {class_name_to_id}")

    # Apply explicit ID overrides.
    if force_class_id:
        for name, new_id in force_class_id.items():
            if name in class_name_to_id:
                old_id = class_name_to_id[name]
                class_name_to_id[name] = new_id
                print(f"强制修改类别映射: {name} {old_id} → {new_id}")
            else:
                print(f"类别 {name} 不存在,跳过")

    print(f"最终类别映射: {class_name_to_id}")

    # An override can collide with an existing ID, which silently merges two
    # classes in the output; surface it so the user can confirm it is intended.
    ids = list(class_name_to_id.values())
    if len(set(ids)) != len(ids):
        print(f"⚠️ 存在重复的类别 ID,这些类别将被合并: {class_name_to_id}")

    for image in root.findall("image"):
        width = int(image.get("width"))
        height = int(image.get("height"))
        txt_path = output_dir / (Path(image.get("name")).stem + ".txt")

        lines = []
        for polygon in image.findall("polygon"):
            label = polygon.get("label")
            # Labels not present in the mapping are deliberately dropped.
            if label not in class_name_to_id:
                continue
            coords = _polygon_to_yolo(polygon.get("points"), width, height)
            # 6 tokens == 3 vertices, the minimum for a valid polygon.
            if len(coords) < 6:
                continue
            lines.append(f"{class_name_to_id[label]} " + " ".join(coords))

        with open(txt_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    print("✅ CVAT segmentation → YOLO seg 转换完成")
|
||
|
||
|
||
if __name__ == "__main__":
    # Example run: convert ./annotations.xml into label files under ./labels_seg,
    # pinning the "yemian" class to ID 0.
    overrides = {"yemian": 0}
    cvat_to_yolo_seg(
        xml_path="annotations.xml",
        output_dir="labels_seg",
        force_class_id=overrides,
    )
|