# File: zjsh_yolov11/yolo11_obb/divid_dataset.py (Python, 84 lines, 2.8 KiB)
# Timestamp shown by the repository viewer: 2025-10-21 14:11:52 +08:00
import os
import shutil
import random
def _move_pair(src_dir: str, dst_dir: str, img: str, txt: str) -> None:
    """Move one image and its label file from ``src_dir`` to ``dst_dir``.

    Raises:
        OSError: If either move fails. When the label cannot be moved after
            the image already was, the image is moved back first so that the
            pair is not left split across two directories.
    """
    img_src = os.path.join(src_dir, img)
    img_dst = os.path.join(dst_dir, img)
    shutil.move(img_src, img_dst)
    try:
        shutil.move(os.path.join(src_dir, txt), os.path.join(dst_dir, txt))
    except OSError:
        try:
            shutil.move(img_dst, img_src)
        except OSError:
            # Best-effort rollback: the original failure re-raised below is
            # the error that gets reported to the caller.
            pass
        raise


def _move_sample(sample: list, src_dir: str, dst_dir: str, split_name: str) -> int:
    """Move every (image, label) pair in ``sample``; return how many succeeded."""
    moved = 0
    for img, txt in sample:
        try:
            _move_pair(src_dir, dst_dir, img, txt)
        except OSError as e:
            # Keep going: one unreadable file should not abort the whole split.
            print(f"❌ 移动到 {split_name} 失败: {img}, {e}")
        else:
            moved += 1
    return moved


def split_train_val_test(train_dir: str, val_dir: str, test_dir: str,
                         val_ratio: float = 0.1, test_ratio: float = 0.1) -> None:
    """Split the image/label pairs in ``train_dir`` into train/val/test.

    An image is paired with the ``.txt`` file of the same stem in the same
    directory. Randomly chosen, disjoint subsets of the pairs are moved to
    ``val_dir`` and ``test_dir``; everything else stays in ``train_dir``.
    Images without a label file are reported and left untouched.

    Args:
        train_dir: Directory holding the images and their ``.txt`` labels.
        val_dir: Destination for the validation pairs (created if missing).
        test_dir: Destination for the test pairs (created if missing).
        val_ratio: Fraction of the pairs to move to ``val_dir``.
        test_ratio: Fraction of the pairs to move to ``test_dir``.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    if not os.path.exists(train_dir):
        raise FileNotFoundError(f"训练目录不存在: {train_dir}")

    # A negative ratio would turn into a negative slice bound below and
    # silently select almost the whole dataset, so reject it up front.
    # The small tolerance absorbs float rounding in sums such as 0.9 + 0.1.
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1 + 1e-9:
        raise ValueError(
            "val_ratio and test_ratio must be non-negative and sum to at most 1, "
            f"got val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Supported image formats (compared case-insensitively).
    img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

    all_files = os.listdir(train_dir)
    # Set lookup keeps the pairing linear in the number of files.
    available = set(all_files)

    # Pair every image with its label file.
    pairs = []
    for img in all_files:
        stem, ext = os.path.splitext(img)
        if ext.lower() not in img_extensions:
            continue
        txt = stem + '.txt'
        if txt in available:
            pairs.append((img, txt))
        else:
            print(f"⚠️ 忽略 {img}:缺少对应标签文件")

    if not pairs:
        print("❌ 未找到有效数据对")
        return

    total = len(pairs)
    num_test = min(int(total * test_ratio), total)
    # Clamp so that val and test can never overlap or exceed the dataset.
    num_val = min(int(total * val_ratio), total - num_test)
    print(f"✅ 共找到 {total} 组有效数据")
    print(f"✅ 将移动val={num_val}, test={num_test}")

    # Shuffle, then take disjoint slices; the remainder stays in train.
    random.shuffle(pairs)
    test_sample = pairs[:num_test]
    val_sample = pairs[num_test:num_test + num_val]

    moved_test = _move_sample(test_sample, train_dir, test_dir, "test")
    moved_val = _move_sample(val_sample, train_dir, val_dir, "val")

    # Report what actually happened, not what was planned.
    print(f"\n✅ 划分完成!")
    print(f" train 保留: {total - moved_val - moved_test}")
    print(f" val : {moved_val}")
    print(f" test : {moved_test}")
    if moved_val == num_val and moved_test == num_test:
        print(f" 所有 val 和 test 数据已从 train 移出。")
# ========================
# 使用示例
# ========================
if __name__ == "__main__":
    # Example invocation: split the obb3 dataset, sending 10% of the pairs to
    # val and another 10% to test.
    dataset_dirs = {
        "train_dir": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/train",
        "val_dir": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/val",
        "test_dir": "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/test",
    }
    split_train_val_test(val_ratio=0.1, test_ratio=0.1, **dataset_dirs)
# ========================