84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
import os
|
||
import shutil
|
||
import random
|
||
|
||
def _collect_pairs(train_dir, img_extensions, label_ext=".txt"):
    """Return sorted (image_name, label_name) pairs found directly inside ``train_dir``.

    An image is a regular file whose lower-cased extension is in ``img_extensions``;
    its label is the file with the same stem and ``label_ext``. Images without a label,
    or whose label is already claimed by another image with the same stem, are skipped
    with a warning.
    """
    # Only regular files: a sub-directory named like "x.jpg" must not be treated as an image.
    with os.scandir(train_dir) as entries:
        file_names = {entry.name for entry in entries if entry.is_file()}

    pairs = []
    used_labels = set()
    # Sorted so the candidate order (and therefore a seeded shuffle) is reproducible.
    for img in sorted(file_names):
        if os.path.splitext(img.lower())[1] not in img_extensions:
            continue
        txt = os.path.splitext(img)[0] + label_ext
        if txt not in file_names:
            print(f"⚠️ 忽略 {img}:缺少对应标签文件")
        elif txt in used_labels:
            # e.g. "a.jpg" and "a.png" share "a.txt": only one image can own the label.
            print(f"⚠️ 忽略 {img}:标签文件 {txt} 已被其他图片使用")
        else:
            used_labels.add(txt)
            pairs.append((img, txt))
    return pairs


def _move_pair(src_dir, dst_dir, img, txt):
    """Move one image/label pair from ``src_dir`` to ``dst_dir``, keeping the pair together.

    Raises:
        FileExistsError: if either destination file already exists.
        OSError: if a move fails; the image is moved back if the label move fails.
    """
    src_img, dst_img = os.path.join(src_dir, img), os.path.join(dst_dir, img)
    src_txt, dst_txt = os.path.join(src_dir, txt), os.path.join(dst_dir, txt)
    # shutil.move may silently overwrite an existing destination file, so refuse up front.
    for dst in (dst_img, dst_txt):
        if os.path.exists(dst):
            raise FileExistsError(f"目标文件已存在: {dst}")
    shutil.move(src_img, dst_img)
    try:
        shutil.move(src_txt, dst_txt)
    except OSError:
        # Put the image back so the pair is not split across directories.
        shutil.move(dst_img, src_img)
        raise


def _move_pairs(pairs, src_dir, dst_dir, split_name):
    """Move every pair into ``dst_dir``; report failures and return the number moved."""
    moved = 0
    for img, txt in pairs:
        try:
            _move_pair(src_dir, dst_dir, img, txt)
        except OSError as e:
            print(f"❌ 移动到 {split_name} 失败: {img}, {e}")
        else:
            moved += 1
    return moved


def split_train_val_test(train_dir, val_dir, test_dir, val_ratio=0.1, test_ratio=0.1, seed=None):
    """Split the image/label pairs in ``train_dir`` into train/val/test by moving files.

    Every image is paired with the ``.txt`` label that shares its stem. After a random
    shuffle, ``int(total * test_ratio)`` pairs are moved to ``test_dir`` and a disjoint
    ``int(total * val_ratio)`` pairs to ``val_dir``; the rest stay in ``train_dir``.

    Args:
        train_dir: Directory holding images and labels; selected pairs are moved out of it.
        val_dir: Destination directory for validation pairs (created if missing).
        test_dir: Destination directory for test pairs (created if missing).
        val_ratio: Fraction of pairs to move to ``val_dir``.
        test_ratio: Fraction of pairs to move to ``test_dir``.
        seed: Optional seed for a reproducible split. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        dict with keys ``"train"``, ``"val"``, ``"test"``. It holds the number of pairs left
        in train and the number actually moved to val/test by this call. Failed moves stay
        counted in train.

    Raises:
        FileNotFoundError: if ``train_dir`` is not an existing directory.
        ValueError: if a ratio is outside [0, 1] or ``val_ratio + test_ratio > 1``.
    """
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"训练目录不存在: {train_dir}")
    # Validate before touching the filesystem; bad ratios used to yield negative train counts.
    if not (0 <= val_ratio <= 1 and 0 <= test_ratio <= 1) or val_ratio + test_ratio > 1:
        raise ValueError(
            f"val_ratio 与 test_ratio 须在 [0, 1] 内且二者之和不超过 1: "
            f"val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Supported image formats (matched case-insensitively).
    img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    pairs = _collect_pairs(train_dir, img_extensions)

    if not pairs:
        print("❌ 未找到有效数据对")
        return {"train": 0, "val": 0, "test": 0}

    total = len(pairs)
    num_test = int(total * test_ratio)
    num_val = int(total * val_ratio)

    print(f"✅ 共找到 {total} 组有效数据")
    print(f"✅ 将移动:val={num_val}, test={num_test}")

    shuffle = random.Random(seed).shuffle if seed is not None else random.shuffle
    shuffle(pairs)
    # Disjoint slices: test first, then val; everything after stays in train.
    test_sample = pairs[:num_test]
    val_sample = pairs[num_test:num_test + num_val]

    moved_test = _move_pairs(test_sample, train_dir, test_dir, "test")
    moved_val = _move_pairs(val_sample, train_dir, val_dir, "val")
    kept_train = total - moved_val - moved_test

    print("\n✅ 划分完成!")
    print(f" train 保留: {kept_train}")
    print(f" val : {moved_val}")
    print(f" test : {moved_test}")
    if moved_val == num_val and moved_test == num_test:
        print(" 所有 val 和 test 数据已从 train 移出。")
    else:
        print(" 部分文件移动失败,失败的数据仍保留在 train 中。")
    return {"train": kept_train, "val": moved_val, "test": moved_test}
|
||
|
||
|
||
# ========================
# Usage example / command-line entry point
# ========================
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded dataset layout; override them per machine
    # on the command line instead of editing the source.
    parser = argparse.ArgumentParser(
        description="Move a random val/test subset of image/label pairs out of a train directory."
    )
    parser.add_argument(
        "--train-dir",
        default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/train",
    )
    parser.add_argument(
        "--val-dir",
        default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/val",
    )
    parser.add_argument(
        "--test-dir",
        default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/test",
    )
    parser.add_argument("--val-ratio", type=float, default=0.1)
    parser.add_argument("--test-ratio", type=float, default=0.1)
    args = parser.parse_args()

    split_train_val_test(
        args.train_dir,
        args.val_dir,
        args.test_dir,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
    )
# ========================
|