Files
zjsh_yolov11/yolo11_obb/data_trans_dota.py
2025-09-01 14:14:18 +08:00

226 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import xml.etree.ElementTree as ET
import math
"""
obb转换代码使用
只用修改if __name__ == '__main__': 下面的三条路径即可:
roxml_path = r"path/已经存放着原始XML文件的文件夹名字"
dotaxml_path = r"path/准备要存放DOTA格式的XML文件的文件夹名字"
# (小建议:先手动创建该文件夹,然后路径放在这里)
out_path = r"path/准备要存放DOTA格式的TXT文件的文件夹名字"
# (小建议:先手动创建该文件夹,然后路径放在这里)
"""
def edit_xml(xml_file, dotaxml_file):
    """
    Rewrite every <object> box of an annotation XML file as four corner points.

    Two input layouts are handled:

    * ``<bndbox>`` with ``xmin/ymin/xmax/ymax`` children (axis-aligned box);
    * ``<robndbox>`` with ``cx/cy/w/h/angle`` children (rotated box), which is
      renamed to ``<bndbox>``.

    In both cases the original children are removed and replaced by
    ``x0, y0, x1, y1, x2, y2, x3, y3``. Objects that carry no box, or whose
    box lacks the expected children (for example a file that has already been
    converted), are left untouched instead of raising ``AttributeError``; this
    makes the conversion safe to run twice on the same file.

    :param xml_file: path of the source XML file
    :param dotaxml_file: path the converted XML file is written to
    """
    tree = ET.parse(xml_file)
    for obj in tree.findall('object'):
        rotated = obj.find('robndbox')
        if rotated is None:
            # Axis-aligned box.
            box = obj.find('bndbox')
            if box is None:
                continue
            nodes = [box.find(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            if any(node is None for node in nodes):
                # Not in min/max form (e.g. already four-point): nothing to do.
                continue
            xmin, ymin, xmax, ymax = (float(node.text) for node in nodes)
            corners = [
                (str(xmin), str(ymax)),
                (str(xmax), str(ymax)),
                (str(xmax), str(ymin)),
                (str(xmin), str(ymin)),
            ]
        else:
            # Rotated box described by centre, size and angle.
            box = rotated
            nodes = [box.find(tag) for tag in ('cx', 'cy', 'w', 'h', 'angle')]
            if any(node is None for node in nodes):
                continue
            cx, cy, w, h, angle = (float(node.text) for node in nodes)
            box.tag = 'bndbox'  # downstream code looks the box up as <bndbox>
            corners = [
                rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle),
                rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle),
                rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle),
                rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle),
            ]

        # Drop the old description, then append the corner nodes in order.
        for node in nodes:
            box.remove(node)
        for index, (x_text, y_text) in enumerate(corners):
            ET.SubElement(box, "x{}".format(index)).text = x_text
            ET.SubElement(box, "y{}".format(index)).text = y_text

    tree.write(dotaxml_file, method='xml', encoding='utf-8')
def rotatePoint(xc, yc, xp, yp, theta):
    """
    Rotate the point (xp, yp) about the centre (xc, yc).

    The offset from the centre is transformed as
    ``x' = cos(theta) * dx + sin(theta) * dy`` and
    ``y' = -sin(theta) * dx + cos(theta) * dy``, then added back to the centre.

    The result is rounded to the nearest integer. Plain ``int()`` truncation
    turns floating-point noise such as ``-9.999999999999998`` into ``-9`` and
    biases every coordinate toward zero.

    :param xc: x coordinate of the rotation centre
    :param yc: y coordinate of the rotation centre
    :param xp: x coordinate of the point
    :param yp: y coordinate of the point
    :param theta: rotation angle in radians
    :return: tuple ``(x, y)`` of the rotated coordinates as integer strings
    """
    dx = xp - xc
    dy = yp - yc
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)
    rotated_x = cos_theta * dx + sin_theta * dy
    rotated_y = -sin_theta * dx + cos_theta * dy
    return str(int(round(xc + rotated_x))), str(int(round(yc + rotated_y)))
def get_unique_classes(xml_path):
    """
    Collect the distinct class names used in a directory of XML files.

    Only entries whose name ends with ``.xml`` are parsed; for each of them
    the text of the ``<name>`` child of every ``<object>`` element is recorded.

    :param xml_path: directory containing the XML files
    :return: set of all class names found
    """
    names = set()
    for entry in os.listdir(xml_path):
        if entry.endswith('.xml'):
            root = ET.parse(os.path.join(xml_path, entry)).getroot()
            names.update(obj.find('name').text for obj in root.findall('object'))
    return names
def generate_class_mapping(unique_classes):
    """
    Assign a consecutive integer id to every class name.

    Names are sorted first, so the same set of names always yields the same
    numbering, starting at 0.

    :param unique_classes: collection of distinct class names
    :return: dict mapping each class name to its integer id
    """
    return {name: number for number, name in enumerate(sorted(unique_classes))}
def totxt(xml_path, out_path, class_mapping):
    """
    Convert four-point XML annotations into one TXT label file per XML file.

    Every ``.xml`` file in ``xml_path`` produces ``<same stem>.txt`` in
    ``out_path``. Each kept object becomes one line:
    ``x0 y0 x1 y1 x2 y2 x3 y3 <class name> <class id>`` with integer
    coordinates.

    Objects are skipped with a printed warning when their class is not in
    ``class_mapping``, or when their ``<bndbox>`` is missing or lacks any of
    the eight corner children. Output is always written as UTF-8 so that
    non-ASCII class names do not depend on the platform default encoding.

    :param xml_path: directory containing the four-point XML files
    :param out_path: directory the TXT files are written to (created if absent)
    :param class_mapping: dict mapping class name to integer id
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_path, exist_ok=True)
    corner_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')

    # Sorted so that processing order (and log output) is reproducible.
    for file in sorted(os.listdir(xml_path)):
        if not file.endswith('.xml'):
            continue
        tree = ET.parse(os.path.join(xml_path, file))
        name = os.path.splitext(file)[0]
        output = os.path.join(out_path, name + '.txt')

        lines = []
        for obj in tree.findall('object'):
            cls = obj.find('name').text
            box = obj.find('bndbox')
            nodes = [] if box is None else [box.find(tag) for tag in corner_tags]
            if not nodes or any(node is None for node in nodes):
                print(f"Warning: Object '{cls}' in {file} has no four-point bndbox. Skipping this object.")
                continue
            # Coordinates may be stored as "12.0"; go through float first.
            coords = [int(float(node.text)) for node in nodes]

            cls_index = class_mapping.get(cls, -1)  # -1 marks an unknown class
            if cls_index == -1:
                print(f"Warning: Class '{cls}' not found in class_mapping. Skipping this object.")
                continue
            lines.append("{} {} {} {} {} {} {} {} {} {}\n".format(*coords, cls, cls_index))

        with open(output, 'w', encoding='utf-8') as f:
            f.writelines(lines)
        print(f"Generated: {output}")
if __name__ == '__main__':
    # Paths to configure before running.
    roxml_path = r"/home/hx/桌面/image/image/1"  # folder holding the original XML files
    # Folder receiving the four-point XML files.
    # NOTE: this currently equals roxml_path, so the originals are overwritten in place.
    dotaxml_path = r"/home/hx/桌面/image/image/1"
    out_path = r"/home/hx/桌面/image/image/2"  # folder receiving the DOTA-style TXT files

    # Step 1: convert every XML annotation to the four-point layout.
    os.makedirs(dotaxml_path, exist_ok=True)
    for file in sorted(os.listdir(roxml_path)):
        # Skip anything that is not an XML file; the folder may hold other
        # files that ET.parse cannot read.
        if not file.endswith('.xml'):
            continue
        print(f"Processing: {os.path.join(roxml_path, file)}")
        edit_xml(os.path.join(roxml_path, file), os.path.join(dotaxml_path, file))

    # Step 2: collect every distinct class name from the converted files.
    unique_classes = get_unique_classes(dotaxml_path)
    print(f"Unique classes found: {unique_classes}")

    # Step 3: build the class-name -> id mapping.
    class_mapping = generate_class_mapping(unique_classes)
    print(f"Generated class mapping: {class_mapping}")

    # Step 4: write the TXT label files.
    totxt(dotaxml_path, out_path, class_mapping)