116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
# pascal_robndbox_to_yolo_obb.py
|
|
import xml.etree.ElementTree as ET
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import argparse
|
|
|
|
def _rotated_box_corners(cx, cy, w, h, angle_rad):
    """Return the four corners of a rotated box as a (4, 2) array in pixels.

    Corners are generated in the order (-w/2, -h/2), (w/2, -h/2), (w/2, h/2),
    (-w/2, h/2) relative to the centre, rotated by ``angle_rad`` with the
    standard 2-D rotation matrix, then translated to ``(cx, cy)``.
    """
    cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
    half_w, half_h = w / 2, h / 2
    corners = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])
    rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    return np.dot(corners, rotation_matrix.T) + [cx, cy]


def robndbox_to_yolo_obb(xml_path, output_dir, class_names, *, angle_in_radians=False):
    """Convert one Pascal VOC XML with <robndbox> objects to a YOLO-OBB .txt.

    Each kept object becomes one line
    ``class_id x1 y1 x2 y2 x3 y3 x4 y4`` where the corner coordinates are
    divided by the image width/height read from ``size/width`` and
    ``size/height``. The file is written to
    ``<output_dir>/labels/<xml stem>.txt`` and only when at least one object
    was converted.

    Args:
        xml_path: Path (``str`` or ``pathlib.Path``) of the XML annotation.
        output_dir: Directory under which the ``labels`` folder is created.
        class_names: Sequence of class names; the position of a name is its
            class id. Objects whose name is not listed are skipped.
        angle_in_radians: When False (default, the original behaviour) the
            ``<angle>`` value is interpreted as degrees. Set to True when the
            annotation tool stores radians.
            NOTE(review): roLabelImg-style <robndbox> files are commonly
            reported to store the angle in radians -- confirm against the
            tool that produced the data.

    Returns:
        None. Progress, skips and errors are reported with ``print``; any
        exception raised while handling the file is caught and reported so
        that a batch run can continue with the next file.
    """
    # Accept both str and Path: only the file name is needed for messages.
    xml_name = Path(xml_path).name
    try:
        root = ET.parse(xml_path).getroot()

        # Image size is required to normalise the coordinates.
        width_elem = root.find("size/width")
        height_elem = root.find("size/height")
        if width_elem is None or height_elem is None:
            print(f"❌ 跳过 {xml_path}: 缺少 size/width 或 size/height")
            return
        img_w = int(width_elem.text)
        img_h = int(height_elem.text)

        # Non-positive sizes cannot be used as a divisor for normalisation.
        if img_w <= 0 or img_h <= 0:
            print(f"❌ 跳过 {xml_path}: 图像尺寸为 0")
            return

        # Output label path.
        label_file = Path(output_dir) / "labels" / (Path(xml_path).stem + ".txt")
        label_file.parent.mkdir(parents=True, exist_ok=True)

        lines = []
        for obj in root.findall("object"):
            name_elem = obj.find("name")
            if name_elem is None:
                # A malformed object must not abort the whole file.
                print(f"⚠️ 跳过缺少 name 的对象 (文件: {xml_name})")
                continue
            name = name_elem.text

            try:
                class_id = class_names.index(name)
            except ValueError:
                print(f"⚠️ 跳过未知类别: {name} (文件: {xml_name})")
                continue

            rb = obj.find("robndbox")
            if rb is None:
                print(f"⚠️ 跳过无 robndbox 的对象: {name}")
                continue

            try:
                # findtext() yields None for a missing child (TypeError) and
                # the raw text otherwise (ValueError when not numeric).
                cx, cy, w, h, angle = (
                    float(rb.findtext(tag)) for tag in ("cx", "cy", "w", "h", "angle")
                )
            except (TypeError, ValueError):
                print(f"⚠️ 跳过 robndbox 字段缺失或无效的对象: {name} (文件: {xml_name})")
                continue

            angle_rad = angle if angle_in_radians else np.radians(angle)
            rotated_corners = _rotated_box_corners(cx, cy, w, h, angle_rad)

            # Normalise by the image size (values are not clipped).
            rotated_corners[:, 0] /= img_w
            rotated_corners[:, 1] /= img_h

            # Flatten to x1 y1 ... x4 y4 and build the YOLO-OBB line.
            coords = rotated_corners.flatten()
            lines.append(f"{class_id} " + " ".join(f"{x:.6f}" for x in coords))

        # Write the file only when there is at least one valid annotation.
        if lines:
            with open(label_file, "w", encoding="utf-8") as f:
                f.write("\n".join(lines) + "\n")
            print(f"✅ 已生成: {label_file}")
        else:
            print(f"🟡 无有效标注,跳过生成: {label_file}")

    except Exception as e:  # deliberate best-effort boundary for batch runs
        print(f"❌ 处理 {xml_path} 时出错: {e}")
|
|
|
|
|
|
def main(argv=None):
    """Batch-convert <robndbox> Pascal VOC XML files to YOLO-OBB label files.

    The input location, output directory and class list can be given on the
    command line; when omitted, the defaults defined below are used.

    Args:
        argv: Optional list of command-line arguments. ``None`` (default)
            lets ``argparse`` read ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the input path does not exist.
    """
    # ==================== Defaults ====================
    # Edit these, or override them with the command-line options below.
    XML_DIR = "/home/hx/开发/ailai_image_obb/ailai_pc/annotations.xml"  # .xml file or directory of .xml files
    OUTPUT_DIR = "yolo_obb_dataset"  # output directory
    CLASS_NAMES = ["ban", "bag"]  # class list; position is the class_id
    # ==================================================

    parser = argparse.ArgumentParser(
        description="Pascal VOC <robndbox> XML -> YOLO-OBB txt"
    )
    parser.add_argument("--xml-dir", default=XML_DIR,
                        help="包含 .xml 文件的目录,或单个 .xml 文件")
    parser.add_argument("--output-dir", default=OUTPUT_DIR, help="输出目录")
    parser.add_argument("--classes", nargs="+", default=CLASS_NAMES,
                        help="类别列表,顺序即 class_id")
    args = parser.parse_args(argv)

    xml_dir = Path(args.xml_dir)
    output_dir = Path(args.output_dir)
    class_names = list(args.classes)

    if not xml_dir.exists():
        raise FileNotFoundError(f"未找到 XML 目录: {xml_dir}")

    # The configured path may be a single annotation file; globbing a file
    # path yields nothing, so handle that case explicitly.
    if xml_dir.is_file():
        xml_files = [xml_dir]
    else:
        # Sorted so that the processing order is reproducible.
        xml_files = sorted(xml_dir.glob("*.xml"))

    if not xml_files:
        print(f"⚠️ 在 {xml_dir} 中未找到 .xml 文件")
        return

    print(f"🔍 找到 {len(xml_files)} 个 XML 文件")
    print(f"📦 类别映射: {dict(enumerate(class_names))}")

    # Convert every file; per-file errors are reported by the converter.
    for xml_file in xml_files:
        robndbox_to_yolo_obb(xml_file, output_dir, class_names)

    print(f"\n🎉 转换完成!标签已保存至: {output_dir / 'labels'}")
|
|
|
|
|
|
# Run the batch conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()