100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
import os
|
||
import xml.etree.ElementTree as ET
|
||
from pathlib import Path
|
||
|
||
|
||
def _label_map_from_xml(root):
    """Build a class-name -> class-id mapping from the first <labels> element in the XML.

    Ids follow the order in which <label> entries appear (0, 1, 2, ...).
    Prints a warning and returns an empty dict when no <labels> element exists.
    """
    class_name_to_id = {}
    labels_elem = root.find(".//labels")
    if labels_elem is None:
        print("⚠️ 未找到 <labels>,请手动提供 class_name_to_id")
        return class_name_to_id

    for idx, label in enumerate(labels_elem.findall("label")):
        name_elem = label.find("name")
        # A <label> without a (non-empty) <name> cannot match any box; skip it
        # instead of crashing on `None.text`. The index is still consumed so the
        # remaining classes keep the id given by their position in the XML.
        if name_elem is None or not name_elem.text:
            print(f"⚠️ 第 {idx} 个 <label> 缺少 <name>,跳过")
            continue
        class_name_to_id[name_elem.text] = idx
    return class_name_to_id


def _box_to_yolo(xtl, ytl, xbr, ybr, width, height):
    """Convert a pixel-space corner box to normalized YOLO (cx, cy, w, h).

    The corners are clipped to the image rectangle (0..width, 0..height)
    *before* the center and size are computed, so the resulting box always
    lies inside [0, 1]. Returns None when the clipped box has no area
    (also covers inverted corners and a zero-sized image).
    """
    x1 = min(max(xtl, 0.0), width)
    x2 = min(max(xbr, 0.0), width)
    y1 = min(max(ytl, 0.0), height)
    y2 = min(max(ybr, 0.0), height)
    if x2 <= x1 or y2 <= y1:
        return None
    return (
        (x1 + x2) / 2 / width,
        (y1 + y2) / 2 / height,
        (x2 - x1) / width,
        (y2 - y1) / height,
    )


def cvat_to_yolo_detect(xml_path, output_dir, class_name_to_id=None):
    """Convert a CVAT XML export (object-detection / box mode) to YOLO Detect labels.

    One ``<image stem>.txt`` file is written per ``<image>`` element; each line
    is ``class_id x_center y_center width height`` normalized to [0, 1] with
    six decimals. Boxes whose label is not in the mapping, or whose area is
    zero after clipping to the image, are skipped with a printed warning.

    Args:
        xml_path (str): path to the XML file exported from CVAT.
        output_dir (str): directory for the output .txt label files
            (created if it does not exist).
        class_name_to_id (dict, optional): mapping from class name to class id.
            If None, ids are assigned automatically in the order of the
            XML's <labels> section (0, 1, 2, ...).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    root = ET.parse(xml_path).getroot()

    # Derive the class order from the XML when no mapping was supplied.
    if class_name_to_id is None:
        class_name_to_id = _label_map_from_xml(root)
    print(f"类别映射: {class_name_to_id}")

    for image in root.findall("image"):
        img_name = image.get("name")
        width = int(image.get("width"))
        height = int(image.get("height"))

        # The label file is named after the image's stem (extension removed).
        # NOTE(review): directory parts of `name` are dropped as well, so two
        # images with the same stem in different folders overwrite each other.
        txt_path = output_dir / f"{Path(img_name).stem}.txt"

        lines = []
        for box in image.findall("box"):
            label = box.get("label")
            if label not in class_name_to_id:
                print(f"⚠️ 未知类别 '{label}',跳过(图片: {img_name})")
                continue

            # NOTE(review): only the corner attributes are read; any rotation
            # attribute CVAT may export is ignored (axis-aligned conversion).
            yolo_box = _box_to_yolo(
                float(box.get("xtl")),
                float(box.get("ytl")),
                float(box.get("xbr")),
                float(box.get("ybr")),
                width,
                height,
            )
            if yolo_box is None:
                print(f"⚠️ 框面积为 0,跳过(图片: {img_name})")
                continue

            x_center, y_center, w, h = yolo_box
            lines.append(
                f"{class_name_to_id[label]} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
            )

        # Write the .txt file even when the image has no boxes (empty file).
        content = ("\n".join(lines) + "\n") if lines else ""
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"✅ {img_name} → {len(lines)} 个目标")

    print(f"\n🎉 转换完成!YOLO 标注已保存至: {output_dir}")
|
||
|
||
|
||
if __name__ == "__main__":
    # ====== Configuration ======
    XML_PATH = "annotations.xml"  # CVAT XML export to convert; replace with your own path
    OUTPUT_LABELS_DIR = "labels"  # directory that receives the YOLO .txt label files

    # Class-name -> class-id mapping. Two options:
    #   1) CLASS_MAP = None -> read the class order from the XML <labels> section
    #   2) an explicit dict  -> ids must match the ones used for training
    CLASS_MAP = {"bag": 0, "bag35": 1}

    # ====== Run the conversion ======
    conversion_kwargs = dict(
        xml_path=XML_PATH,
        output_dir=OUTPUT_LABELS_DIR,
        class_name_to_id=CLASS_MAP,
    )
    cvat_to_yolo_detect(**conversion_kwargs)