226 lines
8.1 KiB
Python
226 lines
8.1 KiB
Python
|
|
import os
|
|||
|
|
import xml.etree.ElementTree as ET
|
|||
|
|
import math
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
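How to use this OBB conversion script:

Only the three paths under ``if __name__ == '__main__':`` need to be edited:

    roxml_path = r"path/name of the folder that already holds the original XML files"
    dotaxml_path = r"path/name of the folder that will hold the DOTA-format XML files"
    # (Tip: create this folder by hand first, then put its path here.)
    out_path = r"path/name of the folder that will hold the DOTA-format TXT files"
    # (Tip: create this folder by hand first, then put its path here.)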
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def edit_xml(xml_file, dotaxml_file):
    """Rewrite every object box of an annotation XML as four corner points.

    Each ``<object>`` is expected to hold either an axis-aligned ``<bndbox>``
    (children ``xmin``/``ymin``/``xmax``/``ymax``) or a rotated ``<robndbox>``
    (children ``cx``/``cy``/``w``/``h``/``angle``). Those children are removed
    and replaced by ``x0, y0, x1, y1, x2, y2, x3, y3``; a ``<robndbox>`` is
    additionally renamed to ``<bndbox>``. When both boxes are present the
    rotated one is converted.

    Objects that have no box at all, or whose box lacks one of the expected
    children (for example a file that was already converted), are left
    untouched and a warning is printed instead of raising ``AttributeError``.

    :param xml_file: path of the original XML file
    :param dotaxml_file: path the converted XML file is written to
    """
    tree = ET.parse(xml_file)
    for obj in tree.findall('object'):
        box = obj.find('robndbox')
        is_rotated = box is not None
        if not is_rotated:
            box = obj.find('bndbox')
        if box is None:
            print(f"Warning: object without bndbox/robndbox in '{xml_file}'. Skipping this object.")
            continue

        if is_rotated:
            source_tags = ('cx', 'cy', 'w', 'h', 'angle')
        else:
            source_tags = ('xmin', 'ymin', 'xmax', 'ymax')
        nodes = [box.find(tag) for tag in source_tags]
        if any(node is None for node in nodes):
            print(f"Warning: box without {'/'.join(source_tags)} in '{xml_file}'. Skipping this object.")
            continue

        # Parse everything before mutating the tree so a bad number cannot
        # leave a half-converted box behind.
        values = [float(node.text) for node in nodes]
        for node in nodes:
            box.remove(node)

        if is_rotated:
            box.tag = 'bndbox'
            cx, cy, w, h, angle = values
            half_w = w / 2
            half_h = h / 2
            # rotatePoint already returns (str, str) pairs.
            corners = [
                rotatePoint(cx, cy, cx - half_w, cy - half_h, -angle),
                rotatePoint(cx, cy, cx + half_w, cy - half_h, -angle),
                rotatePoint(cx, cy, cx + half_w, cy + half_h, -angle),
                rotatePoint(cx, cy, cx - half_w, cy + half_h, -angle),
            ]
        else:
            xmin, ymin, xmax, ymax = values
            corners = [
                (str(xmin), str(ymax)),
                (str(xmax), str(ymax)),
                (str(xmax), str(ymin)),
                (str(xmin), str(ymin)),
            ]

        # Append x0, y0, x1, y1, ... in that order.
        for index, (point_x, point_y) in enumerate(corners):
            ET.SubElement(box, f"x{index}").text = point_x
            ET.SubElement(box, f"y{index}").text = point_y

    tree.write(dotaxml_file, method='xml', encoding='utf-8')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate the point (xp, yp) about (xc, yc) and return pixel coordinates.

    The offset ``(dx, dy) = (xp - xc, yp - yc)`` is mapped to
    ``(cos(theta) * dx + sin(theta) * dy, -sin(theta) * dx + cos(theta) * dy)``
    and added back to the centre.

    :param xc: x coordinate of the rotation centre
    :param yc: y coordinate of the rotation centre
    :param xp: x coordinate of the point to rotate
    :param yp: y coordinate of the point to rotate
    :param theta: rotation angle in radians
    :return: tuple ``(x, y)`` of the rotated point, each an integer rendered
        as a string
    """
    dx = xp - xc
    dy = yp - yc
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)
    rotated_x = cos_theta * dx + sin_theta * dy
    rotated_y = -sin_theta * dx + cos_theta * dy
    # Round to the nearest pixel. int() truncates toward zero, so ordinary
    # floating-point noise (e.g. -4.999999999999999) would land one pixel off
    # and every corner would be biased toward the origin.
    return str(round(xc + rotated_x)), str(round(yc + rotated_y))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_unique_classes(xml_path):
    """Collect the distinct class names used by the XML files in a directory.

    Only directory entries whose name ends with ``.xml`` are parsed. The class
    name of an object is the text of its ``<name>`` child. Objects whose
    ``<name>`` is missing or empty are skipped with a warning, so the returned
    set contains only non-empty strings and can always be sorted.

    :param xml_path: directory containing the XML files
    :return: set of unique class names
    """
    unique_classes = set()
    for file_name in sorted(os.listdir(xml_path)):
        if not file_name.endswith('.xml'):
            continue
        root = ET.parse(os.path.join(xml_path, file_name)).getroot()
        for obj in root.findall('object'):
            # findtext gives None for a missing <name> and '' for an empty one.
            label = obj.findtext('name')
            if not label:
                print(f"Warning: object without a class name in '{file_name}'. Skipping this object.")
                continue
            unique_classes.add(label)
    return unique_classes
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_class_mapping(unique_classes):
    """Assign a consecutive integer id to every class name.

    Names are put in sorted order first, so the same set of names always
    yields the same numbering, starting at 0.

    :param unique_classes: collection of unique class names
    :return: dict mapping each class name to its index
    """
    ordered_names = sorted(unique_classes)
    return {name: position for position, name in enumerate(ordered_names)}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def totxt(xml_path, out_path, class_mapping):
    """Write one TXT label file per four-point XML file.

    For every ``*.xml`` file in ``xml_path`` a file with the same base name and
    a ``.txt`` extension is created in ``out_path``. Each object becomes one
    line: ``x0 y0 x1 y1 x2 y2 x3 y3 <class name> <class index>``, with the
    coordinates truncated to integers.

    Objects are skipped with a warning when they have no class name, no
    complete four-point ``<bndbox>``, or a class that is absent from
    ``class_mapping``.

    :param xml_path: directory containing the four-point XML files
    :param out_path: directory the TXT files are written to (created if needed)
    :param class_mapping: dict mapping class name to class index
    """
    os.makedirs(out_path, exist_ok=True)
    point_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')

    for file in sorted(os.listdir(xml_path)):
        if not file.endswith('.xml'):
            continue

        tree = ET.parse(os.path.join(xml_path, file))
        name = os.path.splitext(file)[0]
        output = os.path.join(out_path, name + '.txt')

        # Explicit UTF-8 so non-ASCII class names do not depend on the locale.
        with open(output, 'w', encoding='utf-8') as f:
            for obj in tree.findall('object'):
                cls = obj.findtext('name')
                box = obj.find('bndbox')
                if box is None:
                    point_texts = []
                else:
                    point_texts = [box.findtext(tag) for tag in point_tags]
                if not cls or not point_texts or any(not text for text in point_texts):
                    print(f"Warning: object without a class name or four-point bndbox in '{file}'. Skipping this object.")
                    continue
                coords = [int(float(text)) for text in point_texts]

                cls_index = class_mapping.get(cls, -1)
                if cls_index == -1:
                    print(f"Warning: Class '{cls}' not found in class_mapping. Skipping this object.")
                    continue

                f.write(" ".join(str(item) for item in (*coords, cls, cls_index)) + "\n")
        print(f"Generated: {output}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Paths to configure.
    roxml_path = r"/home/hx/桌面/image/1"  # folder holding the original XML files
    dotaxml_path = r"/home/hx/桌面/image/2"  # folder receiving the four-point (DOTA-style) XML files
    out_path = r"/home/hx/桌面/image/2"  # folder receiving the DOTA-style TXT files

    # Step 1: convert every XML file to the unified four-point format.
    # The output folder is created if missing, and only *.xml entries are
    # processed so images or other files in the input folder do not crash
    # the XML parser.
    os.makedirs(dotaxml_path, exist_ok=True)
    for file in sorted(os.listdir(roxml_path)):
        if not file.endswith('.xml'):
            continue
        source_file = os.path.join(roxml_path, file)
        print(f"Processing: {source_file}")
        edit_xml(source_file, os.path.join(dotaxml_path, file))

    # Step 2: collect every class name used in the converted XML files.
    unique_classes = get_unique_classes(dotaxml_path)
    print(f"Unique classes found: {unique_classes}")

    # Step 3: build the class-name -> index mapping.
    class_mapping = generate_class_mapping(unique_classes)
    print(f"Generated class mapping: {class_mapping}")

    # Step 4: write the TXT label files from the converted XML files.
    totxt(dotaxml_path, out_path, class_mapping)
|