Files
zjsh_yolov11/yolo11_obb/trans_cvattoobb.py

128 lines
4.0 KiB
Python
Raw Normal View History

2025-09-05 14:29:33 +08:00
# cvat_xml_to_yolo_obb.py
# 仅在有标注时生成 .txt 文件
import xml.etree.ElementTree as ET
import numpy as np
from pathlib import Path
def rotate_box(x_center, y_center, w, h, angle_deg):
"""
将旋转框转为 4 个角点坐标未归一化
"""
angle_rad = np.radians(angle_deg)
# 四个角点相对于中心的偏移
corners = np.array([
[-w/2, -h/2],
[ w/2, -h/2],
[ w/2, h/2],
[-w/2, h/2]
])
# 旋转矩阵(顺时针)
cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
rotation_matrix = np.array([[cos_a, -sin_a],
[sin_a, cos_a]])
# 旋转并平移
rotated_corners = np.dot(corners, rotation_matrix.T) + np.array([x_center, y_center])
return rotated_corners # 返回 (4, 2) 数组
def cvat_xml_to_yolo_obb(cvat_xml_path, output_dir, class_name_to_id=None):
"""
CVAT annotations.xml 转为 YOLO-OBB 格式
- 仅在有有效标注时创建 .txt 文件
"""
if class_name_to_id is None:
class_name_to_id = {"clamp": 0} # ✅ 请根据你的实际类别修改
tree = ET.parse(cvat_xml_path)
root = tree.getroot()
# 创建 labels 输出目录
labels_dir = Path(output_dir) / "labels"
labels_dir.mkdir(parents=True, exist_ok=True)
# 统计信息
processed_images = 0
saved_files = 0
skipped_classes = 0
for image_elem in root.findall("image"):
image_name = image_elem.get("name")
img_w = float(image_elem.get("width"))
img_h = float(image_elem.get("height"))
label_file = (labels_dir / Path(image_name).stem).with_suffix(".txt")
boxes = image_elem.findall("box")
valid_annotations = []
for box in boxes:
label = box.get("label")
if label not in class_name_to_id:
print(f"⚠️ 跳过未知类别: {label} (图片: {image_name})")
skipped_classes += 1
continue
class_id = class_name_to_id[label]
xtl = float(box.get("xtl"))
ytl = float(box.get("ytl"))
xbr = float(box.get("xbr"))
ybr = float(box.get("ybr"))
# 计算中心点和宽高
x_center = (xtl + xbr) / 2
y_center = (ytl + ybr) / 2
w = xbr - xtl
h = ybr - ytl
# 获取旋转角度
angle = float(box.get("rotation", 0.0))
# 计算 4 个角点
corners = rotate_box(x_center, y_center, w, h, angle)
# 归一化到 [0,1]
corners[:, 0] /= img_w
corners[:, 1] /= img_h
# 展平并生成行
points = corners.flatten()
line = str(class_id) + " " + " ".join(f"{coord:.6f}" for coord in points)
valid_annotations.append(line)
# ✅ 只有存在有效标注时才写入文件
if valid_annotations:
with open(label_file, 'w') as f:
f.write("\n".join(valid_annotations) + "\n")
saved_files += 1
print(f"✅ 已生成: {label_file}")
# else: # 无标注,不创建文件
# print(f"🟡 无标注,跳过: {image_name}")
processed_images += 1
print(f"\n🎉 转换完成!")
print(f"📊 处理图片: {processed_images}")
print(f"✅ 生成标签: {saved_files}")
if skipped_classes:
print(f"⚠️ 跳过类别: {skipped_classes} 个标注")
# ==================== 使用示例 ====================
if __name__ == "__main__":
# ✅ 请修改以下路径和类别
CVAT_XML_PATH = "annotations.xml" # 你的 annotations.xml 文件
OUTPUT_DIR = "yolo_obb_dataset" # 输出目录
CLASS_MAPPING = {
"clamp": 0, # 请根据你的实际类别修改
# "other_class": 1,
}
# 执行转换
cvat_xml_to_yolo_obb(
cvat_xml_path=CVAT_XML_PATH,
output_dir=OUTPUT_DIR,
class_name_to_id=CLASS_MAPPING
)