最新推送
This commit is contained in:
106
yolo11_seg/cvattoseg.py
Normal file
106
yolo11_seg/cvattoseg.py
Normal file
@ -0,0 +1,106 @@
|
||||
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def cvat_to_yolo_seg(
|
||||
xml_path,
|
||||
output_dir,
|
||||
class_name_to_id=None,
|
||||
force_class_id=None
|
||||
):
|
||||
"""
|
||||
将 CVAT 导出的 XML(polygon / segmentation)转换为 YOLO Segmentation 格式
|
||||
|
||||
Args:
|
||||
xml_path (str): CVAT 导出的 XML 文件路径
|
||||
output_dir (str): 输出 .txt 标注文件的目录
|
||||
class_name_to_id (dict, optional):
|
||||
类别名到 ID 的映射
|
||||
force_class_id (dict, optional):
|
||||
强制指定某些类别的 ID,例如 {"yemian": 0}
|
||||
"""
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# ----------------------------
|
||||
# 自动提取类别映射
|
||||
# ----------------------------
|
||||
if class_name_to_id is None:
|
||||
class_name_to_id = {}
|
||||
labels_elem = root.find(".//labels")
|
||||
if labels_elem is not None:
|
||||
for idx, label in enumerate(labels_elem.findall("label")):
|
||||
name = label.find("name").text
|
||||
class_name_to_id[name] = idx
|
||||
else:
|
||||
raise RuntimeError("❌ 未找到 <labels>,请手动提供 class_name_to_id")
|
||||
|
||||
print(f"原始类别映射: {class_name_to_id}")
|
||||
|
||||
# ----------------------------
|
||||
# 强制修改类别 ID(新增功能)
|
||||
# ----------------------------
|
||||
if force_class_id:
|
||||
for name, new_id in force_class_id.items():
|
||||
if name in class_name_to_id:
|
||||
old_id = class_name_to_id[name]
|
||||
class_name_to_id[name] = new_id
|
||||
print(f"强制修改类别映射: {name} {old_id} → {new_id}")
|
||||
else:
|
||||
print(f"类别 {name} 不存在,跳过")
|
||||
|
||||
print(f"最终类别映射: {class_name_to_id}")
|
||||
|
||||
# ----------------------------
|
||||
# 遍历每一张 image
|
||||
# ----------------------------
|
||||
for image in root.findall("image"):
|
||||
img_name = image.get("name")
|
||||
width = int(image.get("width"))
|
||||
height = int(image.get("height"))
|
||||
|
||||
txt_path = output_dir / (Path(img_name).stem + ".txt")
|
||||
lines = []
|
||||
|
||||
for polygon in image.findall("polygon"):
|
||||
label = polygon.get("label")
|
||||
if label not in class_name_to_id:
|
||||
continue
|
||||
|
||||
class_id = class_name_to_id[label]
|
||||
points_str = polygon.get("points")
|
||||
|
||||
points = []
|
||||
for p in points_str.strip().split(";"):
|
||||
x, y = p.split(",")
|
||||
x = float(x) / width
|
||||
y = float(y) / height
|
||||
points.append(f"{x:.6f}")
|
||||
points.append(f"{y:.6f}")
|
||||
|
||||
if len(points) < 6:
|
||||
continue
|
||||
|
||||
lines.append(f"{class_id} " + " ".join(points))
|
||||
|
||||
with open(txt_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
print("✅ CVAT segmentation → YOLO seg 转换完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
cvat_to_yolo_seg(
|
||||
xml_path="annotations.xml",
|
||||
output_dir="labels_seg",
|
||||
force_class_id={
|
||||
"yemian": 0
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user