84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
|
|
import os
|
|||
|
|
import shutil
|
|||
|
|
import random
|
|||
|
|
|
|||
|
|
def split_train_val_test(train_dir, val_dir, test_dir, val_ratio=0.1, test_ratio=0.1):
|
|||
|
|
"""
|
|||
|
|
将 train 目录中的数据按比例划分为 train/val/test 三部分。
|
|||
|
|
val 和 test 各自获得不同的一部分数据。
|
|||
|
|
"""
|
|||
|
|
# 检查目录
|
|||
|
|
if not os.path.exists(train_dir):
|
|||
|
|
raise FileNotFoundError(f"训练目录不存在: {train_dir}")
|
|||
|
|
|
|||
|
|
os.makedirs(val_dir, exist_ok=True)
|
|||
|
|
os.makedirs(test_dir, exist_ok=True)
|
|||
|
|
|
|||
|
|
# 支持的图片格式
|
|||
|
|
img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
|||
|
|
|
|||
|
|
# 获取所有图片文件
|
|||
|
|
all_files = os.listdir(train_dir)
|
|||
|
|
image_files = [f for f in all_files if os.path.splitext(f.lower())[1] in img_extensions]
|
|||
|
|
|
|||
|
|
# 配对图片和标签
|
|||
|
|
pairs = []
|
|||
|
|
for img in image_files:
|
|||
|
|
name_no_ext = os.path.splitext(img)[0]
|
|||
|
|
txt = name_no_ext + '.txt'
|
|||
|
|
if txt in all_files:
|
|||
|
|
pairs.append((img, txt))
|
|||
|
|
else:
|
|||
|
|
print(f"⚠️ 忽略 {img}:缺少对应标签文件")
|
|||
|
|
|
|||
|
|
if len(pairs) == 0:
|
|||
|
|
print("❌ 未找到有效数据对")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
total = len(pairs)
|
|||
|
|
num_test = int(total * test_ratio)
|
|||
|
|
num_val = int(total * val_ratio)
|
|||
|
|
|
|||
|
|
print(f"✅ 共找到 {total} 组有效数据")
|
|||
|
|
print(f"✅ 将移动:val={num_val}, test={num_test}")
|
|||
|
|
|
|||
|
|
# 打乱并抽样
|
|||
|
|
random.shuffle(pairs)
|
|||
|
|
test_sample = pairs[:num_test]
|
|||
|
|
val_sample = pairs[num_test:num_test + num_val]
|
|||
|
|
# 剩下的留在 train
|
|||
|
|
|
|||
|
|
# 移动到 test
|
|||
|
|
for img, txt in test_sample:
|
|||
|
|
try:
|
|||
|
|
shutil.move(os.path.join(train_dir, img), os.path.join(test_dir, img))
|
|||
|
|
shutil.move(os.path.join(train_dir, txt), os.path.join(test_dir, txt))
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ 移动到 test 失败: {img}, {e}")
|
|||
|
|
|
|||
|
|
# 移动到 val
|
|||
|
|
for img, txt in val_sample:
|
|||
|
|
try:
|
|||
|
|
shutil.move(os.path.join(train_dir, img), os.path.join(val_dir, img))
|
|||
|
|
shutil.move(os.path.join(train_dir, txt), os.path.join(val_dir, txt))
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ 移动到 val 失败: {img}, {e}")
|
|||
|
|
|
|||
|
|
print(f"\n✅ 划分完成!")
|
|||
|
|
print(f" train 保留: {total - num_val - num_test}")
|
|||
|
|
print(f" val : {num_val}")
|
|||
|
|
print(f" test : {num_test}")
|
|||
|
|
print(f" 所有 val 和 test 数据已从 train 移出。")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========================
|
|||
|
|
# 使用示例
|
|||
|
|
# ========================
|
|||
|
|
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded dataset layout, so running the
    # script with no arguments operates on exactly the same directories.
    default_root = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3"

    parser = argparse.ArgumentParser(
        description="Move a random val/test subset of image/label pairs out of <root>/train."
    )
    parser.add_argument(
        "--root",
        default=default_root,
        help="dataset root containing train/; val/ and test/ are created next to it",
    )
    parser.add_argument("--val-ratio", type=float, default=0.1, help="fraction moved to val/")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="fraction moved to test/")
    args = parser.parse_args()

    # One root instead of three duplicated absolute paths.
    split_train_val_test(
        os.path.join(args.root, "train"),
        os.path.join(args.root, "val"),
        os.path.join(args.root, "test"),
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
    )
|
|||
|
|
# ========================
|