first commit
This commit is contained in:
116
ailai_pc/trans_cvattoobb.py
Normal file
116
ailai_pc/trans_cvattoobb.py
Normal file
@ -0,0 +1,116 @@
|
||||
# pascal_robndbox_to_yolo_obb.py
|
||||
import xml.etree.ElementTree as ET
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
def robndbox_to_yolo_obb(xml_path, output_dir, class_names):
|
||||
"""
|
||||
将单个带有 <robndbox> 的 Pascal VOC XML 转换为 YOLO-OBB 格式 .txt
|
||||
"""
|
||||
try:
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# 获取图像尺寸
|
||||
width_elem = root.find("size/width")
|
||||
height_elem = root.find("size/height")
|
||||
if width_elem is None or height_elem is None:
|
||||
print(f"❌ 跳过 {xml_path}: 缺少 size/width 或 size/height")
|
||||
return
|
||||
img_w = int(width_elem.text)
|
||||
img_h = int(height_elem.text)
|
||||
|
||||
if img_w == 0 or img_h == 0:
|
||||
print(f"❌ 跳过 {xml_path}: 图像尺寸为 0")
|
||||
return
|
||||
|
||||
# 输出文件路径
|
||||
label_file = Path(output_dir) / "labels" / (Path(xml_path).stem + ".txt")
|
||||
label_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
lines = []
|
||||
for obj in root.findall("object"):
|
||||
name = obj.find("name").text
|
||||
if name not in class_names:
|
||||
print(f"⚠️ 跳过未知类别: {name} (文件: {xml_path.name})")
|
||||
continue
|
||||
class_id = class_names.index(name)
|
||||
|
||||
rb = obj.find("robndbox")
|
||||
if rb is None:
|
||||
print(f"⚠️ 跳过无 robndbox 的对象: {name}")
|
||||
continue
|
||||
|
||||
cx = float(rb.find("cx").text)
|
||||
cy = float(rb.find("cy").text)
|
||||
w = float(rb.find("w").text)
|
||||
h = float(rb.find("h").text)
|
||||
angle_deg = float(rb.find("angle").text)
|
||||
|
||||
# 计算四个角点(相对于中心旋转)
|
||||
angle_rad = np.radians(angle_deg)
|
||||
cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
|
||||
|
||||
corners = np.array([
|
||||
[-w/2, -h/2],
|
||||
[ w/2, -h/2],
|
||||
[ w/2, h/2],
|
||||
[-w/2, h/2]
|
||||
])
|
||||
rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
|
||||
rotated_corners = np.dot(corners, rotation_matrix.T) + [cx, cy]
|
||||
|
||||
# 归一化到 [0,1]
|
||||
rotated_corners[:, 0] /= img_w
|
||||
rotated_corners[:, 1] /= img_h
|
||||
|
||||
# 展平并生成 YOLO-OBB 行
|
||||
coords = rotated_corners.flatten()
|
||||
line = str(class_id) + " " + " ".join(f"{x:.6f}" for x in coords)
|
||||
lines.append(line)
|
||||
|
||||
# 只有存在有效标注才写入文件
|
||||
if lines:
|
||||
with open(label_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines) + "\n")
|
||||
print(f"✅ 已生成: {label_file}")
|
||||
else:
|
||||
print(f"🟡 无有效标注,跳过生成: {label_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 处理 {xml_path} 时出错: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
# ==================== 配置区 ====================
|
||||
# ✅ 修改以下路径和类别
|
||||
XML_DIR = "/home/hx/桌面/ailai_test/train" # 包含 .xml 文件的目录
|
||||
OUTPUT_DIR = "yolo_obb_dataset" # 输出目录
|
||||
CLASS_NAMES = ["ban", "bag"] # 你的类别列表,顺序即 class_id
|
||||
# ==============================================
|
||||
|
||||
xml_dir = Path(XML_DIR)
|
||||
output_dir = Path(OUTPUT_DIR)
|
||||
|
||||
if not xml_dir.exists():
|
||||
raise FileNotFoundError(f"未找到 XML 目录: {xml_dir}")
|
||||
|
||||
# 查找所有 .xml 文件
|
||||
xml_files = list(xml_dir.glob("*.xml"))
|
||||
if not xml_files:
|
||||
print(f"⚠️ 在 {xml_dir} 中未找到 .xml 文件")
|
||||
return
|
||||
|
||||
print(f"🔍 找到 {len(xml_files)} 个 XML 文件")
|
||||
print(f"📦 类别映射: { {i: name for i, name in enumerate(CLASS_NAMES)} }")
|
||||
|
||||
# 批量转换
|
||||
for xml_file in xml_files:
|
||||
robndbox_to_yolo_obb(xml_file, output_dir, CLASS_NAMES)
|
||||
|
||||
print(f"\n🎉 转换完成!标签已保存至: {output_dir / 'labels'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user