107 lines
3.1 KiB
Python
107 lines
3.1 KiB
Python
|
|
|
|||
|
|
import os
|
|||
|
|
import xml.etree.ElementTree as ET
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def cvat_to_yolo_seg(
|
|||
|
|
xml_path,
|
|||
|
|
output_dir,
|
|||
|
|
class_name_to_id=None,
|
|||
|
|
force_class_id=None
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
将 CVAT 导出的 XML(polygon / segmentation)转换为 YOLO Segmentation 格式
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
xml_path (str): CVAT 导出的 XML 文件路径
|
|||
|
|
output_dir (str): 输出 .txt 标注文件的目录
|
|||
|
|
class_name_to_id (dict, optional):
|
|||
|
|
类别名到 ID 的映射
|
|||
|
|
force_class_id (dict, optional):
|
|||
|
|
强制指定某些类别的 ID,例如 {"yemian": 0}
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
output_dir = Path(output_dir)
|
|||
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
|
|
|||
|
|
tree = ET.parse(xml_path)
|
|||
|
|
root = tree.getroot()
|
|||
|
|
|
|||
|
|
# ----------------------------
|
|||
|
|
# 自动提取类别映射
|
|||
|
|
# ----------------------------
|
|||
|
|
if class_name_to_id is None:
|
|||
|
|
class_name_to_id = {}
|
|||
|
|
labels_elem = root.find(".//labels")
|
|||
|
|
if labels_elem is not None:
|
|||
|
|
for idx, label in enumerate(labels_elem.findall("label")):
|
|||
|
|
name = label.find("name").text
|
|||
|
|
class_name_to_id[name] = idx
|
|||
|
|
else:
|
|||
|
|
raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")
|
|||
|
|
|
|||
|
|
print(f"原始类别映射: {class_name_to_id}")
|
|||
|
|
|
|||
|
|
# ----------------------------
|
|||
|
|
# 强制修改类别 ID(新增功能)
|
|||
|
|
# ----------------------------
|
|||
|
|
if force_class_id:
|
|||
|
|
for name, new_id in force_class_id.items():
|
|||
|
|
if name in class_name_to_id:
|
|||
|
|
old_id = class_name_to_id[name]
|
|||
|
|
class_name_to_id[name] = new_id
|
|||
|
|
print(f"强制修改类别映射: {name} {old_id} → {new_id}")
|
|||
|
|
else:
|
|||
|
|
print(f"类别 {name} 不存在,跳过")
|
|||
|
|
|
|||
|
|
print(f"最终类别映射: {class_name_to_id}")
|
|||
|
|
|
|||
|
|
# ----------------------------
|
|||
|
|
# 遍历每一张 image
|
|||
|
|
# ----------------------------
|
|||
|
|
for image in root.findall("image"):
|
|||
|
|
img_name = image.get("name")
|
|||
|
|
width = int(image.get("width"))
|
|||
|
|
height = int(image.get("height"))
|
|||
|
|
|
|||
|
|
txt_path = output_dir / (Path(img_name).stem + ".txt")
|
|||
|
|
lines = []
|
|||
|
|
|
|||
|
|
for polygon in image.findall("polygon"):
|
|||
|
|
label = polygon.get("label")
|
|||
|
|
if label not in class_name_to_id:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
class_id = class_name_to_id[label]
|
|||
|
|
points_str = polygon.get("points")
|
|||
|
|
|
|||
|
|
points = []
|
|||
|
|
for p in points_str.strip().split(";"):
|
|||
|
|
x, y = p.split(",")
|
|||
|
|
x = float(x) / width
|
|||
|
|
y = float(y) / height
|
|||
|
|
points.append(f"{x:.6f}")
|
|||
|
|
points.append(f"{y:.6f}")
|
|||
|
|
|
|||
|
|
if len(points) < 6:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
lines.append(f"{class_id} " + " ".join(points))
|
|||
|
|
|
|||
|
|
with open(txt_path, "w", encoding="utf-8") as f:
|
|||
|
|
f.write("\n".join(lines))
|
|||
|
|
|
|||
|
|
print("✅ CVAT segmentation → YOLO seg 转换完成")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Command-line entry point. With no arguments this performs exactly the
    # original hard-coded run: annotations.xml -> labels_seg/ with yemian forced to 0.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a CVAT XML export (polygons) to YOLO segmentation labels."
    )
    parser.add_argument("--xml", default="annotations.xml", help="CVAT XML export path")
    parser.add_argument("--out", default="labels_seg", help="output directory for .txt labels")
    parser.add_argument(
        "--force",
        action="append",
        metavar="NAME=ID",
        help="override a class ID, e.g. --force yemian=0 (repeatable; default: yemian=0)",
    )
    args = parser.parse_args()

    if args.force:
        force = {}
        for item in args.force:
            name, sep, class_id = item.partition("=")
            try:
                if not sep or not name:
                    raise ValueError(item)
                force[name] = int(class_id)
            except ValueError:
                parser.error(f"--force expects NAME=ID, got {item!r}")
    else:
        force = {"yemian": 0}

    cvat_to_yolo_seg(
        xml_path=args.xml,
        output_dir=args.out,
        force_class_id=force,
    )
|