142 lines
4.1 KiB
Python
142 lines
4.1 KiB
Python
import os
|
||
import xml.etree.ElementTree as ET
|
||
from pathlib import Path
|
||
|
||
|
||
# Lower-cased file extensions that are treated as images.
_IMAGE_SUFFIXES = frozenset({".jpg", ".png", ".jpeg"})


def _read_image_size(img_path):
    """Return ``(width, height)`` of the image, or ``None`` if OpenCV cannot read it."""
    # Imported lazily: a run that finds no image files never needs OpenCV.
    import cv2

    img = cv2.imread(str(img_path))
    if img is None:
        return None
    height, width = img.shape[:2]
    return width, height


def _parse_yolo_line(line):
    """Parse ``"class xc yc w h"`` into ``(class_id, xc, yc, w, h)``; ``None`` if malformed."""
    fields = line.split()
    if len(fields) != 5:
        return None
    try:
        class_id, xc, yc, w, h = map(float, fields)
        # int() rejects nan (ValueError) and inf (OverflowError).
        return int(class_id), xc, yc, w, h
    except (ValueError, OverflowError):
        return None


def yolo_detect_to_cvat(
    images_dir,
    labels_dir,
    output_xml,
    class_id_to_name,
):
    """
    Convert YOLO detection labels into a CVAT XML file (bounding boxes).

    Every image file directly inside ``images_dir`` gets an ``<image>`` element,
    whether or not a label file exists for it. The label file is looked up as
    ``labels_dir/<image stem>.txt``; each valid line in it becomes a ``<box>``.
    Images that OpenCV cannot read are reported and skipped without consuming
    an image id. Label lines that are malformed or reference an unknown class
    id are reported and skipped instead of aborting the conversion.

    Progress and warnings are printed to stdout.

    Args:
        images_dir (str): Directory containing the images.
        labels_dir (str): Directory containing the YOLO ``.txt`` label files.
        output_xml (str): Path of the CVAT XML file to write.
        class_id_to_name (dict): Mapping of class id to label name,
            e.g. ``{0: "xxx", 1: "yyy"}``.
    """
    images_dir = Path(images_dir)
    labels_dir = Path(labels_dir)

    # ----------------------------
    # Build the XML skeleton
    # ----------------------------
    annotations = ET.Element("annotations")
    ET.SubElement(annotations, "version").text = "1.1"

    # Label definitions: annotations/meta/task/labels/label/name
    meta = ET.SubElement(annotations, "meta")
    task = ET.SubElement(meta, "task")
    labels_elem = ET.SubElement(task, "labels")
    for name in class_id_to_name.values():
        label = ET.SubElement(labels_elem, "label")
        ET.SubElement(label, "name").text = name

    # Ids are handed out only to images that were actually written.
    image_id = 0

    # ----------------------------
    # Walk the images (sorted for a deterministic id order)
    # ----------------------------
    for img_path in sorted(images_dir.iterdir()):
        if img_path.suffix.lower() not in _IMAGE_SUFFIXES:
            continue

        size = _read_image_size(img_path)
        if size is None:
            print(f"⚠️ 无法读取图片: {img_path}")
            continue
        width, height = size

        image_elem = ET.SubElement(
            annotations,
            "image",
            {
                "id": str(image_id),
                "name": img_path.name,
                "width": str(width),
                "height": str(height),
            },
        )

        label_txt = labels_dir / f"{img_path.stem}.txt"

        # The <image> element is kept even when there is no label file.
        if label_txt.exists():
            with open(label_txt, "r", encoding="utf-8") as f:
                for line_no, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue

                    parsed = _parse_yolo_line(line)
                    if parsed is None:
                        # One bad line must not abort the whole conversion.
                        print(f"⚠️ 跳过格式错误的标注行: {label_txt}:{line_no}: {line}")
                        continue
                    class_id, xc, yc, w, h = parsed

                    label_name = class_id_to_name.get(class_id)
                    if label_name is None:
                        print(f"⚠️ 未知 class_id: {class_id}")
                        continue

                    # YOLO (normalized center + size) -> CVAT (pixel corners)
                    box_w = w * width
                    box_h = h * height
                    xtl = xc * width - box_w / 2
                    ytl = yc * height - box_h / 2
                    xbr = xtl + box_w
                    ybr = ytl + box_h

                    ET.SubElement(
                        image_elem,
                        "box",
                        {
                            "label": label_name,
                            "xtl": f"{xtl:.2f}",
                            "ytl": f"{ytl:.2f}",
                            "xbr": f"{xbr:.2f}",
                            "ybr": f"{ybr:.2f}",
                            "occluded": "0",
                            "source": "manual",
                        },
                    )

        image_id += 1
        print(f"✅ {img_path.name}")

    # ----------------------------
    # Write the XML file
    # ----------------------------
    tree = ET.ElementTree(annotations)
    tree.write(output_xml, encoding="utf-8", xml_declaration=True)

    print(f"\n🎉 转换完成,CVAT XML 已生成:{output_xml}")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: convert one hard-coded dataset directory.
    # Images and their YOLO .txt files are read from the same directory.
    IMAGES_DIR = LABELS_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/detect/1"
    OUTPUT_XML = "annotations.xml"

    # WARNING: this mapping must match the one used when training the YOLO model.
    # A previously used mapping, kept for reference: 0 -> "hole", 1 -> "crack".
    CLASS_ID_TO_NAME = {0: "bag"}

    yolo_detect_to_cvat(IMAGES_DIR, LABELS_DIR, OUTPUT_XML, CLASS_ID_TO_NAME)
|