import os
import random
import shutil
from typing import List, Optional, Tuple

# Lower-case file extensions that are treated as images.
_IMG_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'})


def _move_pair(src_dir: str, dst_dir: str, img: str, txt: str) -> None:
    """Move one image/label pair, keeping the pair together on failure."""
    img_src = os.path.join(src_dir, img)
    img_dst = os.path.join(dst_dir, img)
    shutil.move(img_src, img_dst)
    try:
        shutil.move(os.path.join(src_dir, txt), os.path.join(dst_dir, txt))
    except OSError:
        # The label could not follow its image: put the image back so the
        # pair is never split across two directories.
        shutil.move(img_dst, img_src)
        raise


def _move_pairs(pairs: List[Tuple[str, str]], src_dir: str, dst_dir: str,
                label: str) -> int:
    """Move every pair to *dst_dir*; return how many pairs actually moved."""
    moved = 0
    for img, txt in pairs:
        try:
            _move_pair(src_dir, dst_dir, img, txt)
        except OSError as e:
            # Best effort: report the failure and carry on with the rest.
            print(f"❌ 移动到 {label} 失败: {img}, {e}")
        else:
            moved += 1
    return moved


def split_train_val_test(train_dir: str, val_dir: str, test_dir: str,
                         val_ratio: float = 0.1, test_ratio: float = 0.1,
                         seed: Optional[int] = None) -> None:
    """Split the image/label pairs in ``train_dir`` into train/val/test.

    An image is only considered when a ``.txt`` file with the same stem sits
    next to it. A random, disjoint subset of the pairs is moved into
    ``val_dir`` and another into ``test_dir``; everything else stays in
    ``train_dir``. Progress and a final summary are printed to stdout.

    Args:
        train_dir: Directory holding the images and their ``.txt`` labels.
        val_dir: Destination for the validation subset (created if missing).
        test_dir: Destination for the test subset (created if missing).
        val_ratio: Fraction of the pairs to move to ``val_dir``.
        test_ratio: Fraction of the pairs to move to ``test_dir``.
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` state is used.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
        ValueError: If a ratio is outside ``[0, 1]`` or the two ratios sum
            to more than 1.
    """
    # Validate inputs before touching the file system.
    if not os.path.exists(train_dir):
        raise FileNotFoundError(f"训练目录不存在: {train_dir}")
    if (not 0 <= val_ratio <= 1 or not 0 <= test_ratio <= 1
            or val_ratio + test_ratio > 1 + 1e-9):
        raise ValueError(
            "val_ratio 和 test_ratio 必须在 [0, 1] 范围内且总和不超过 1: "
            f"val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Collect the image files; the extension check is case-insensitive.
    all_files = os.listdir(train_dir)
    known_names = set(all_files)  # O(1) label lookups
    image_files = [f for f in all_files
                   if os.path.splitext(f.lower())[1] in _IMG_EXTENSIONS]

    # Pair each image with its label file; images without a label are skipped.
    pairs = []
    for img in image_files:
        txt = os.path.splitext(img)[0] + '.txt'
        if txt in known_names:
            pairs.append((img, txt))
        else:
            print(f"⚠️ 忽略 {img}:缺少对应标签文件")

    if not pairs:
        print("❌ 未找到有效数据对")
        return

    total = len(pairs)
    # Clamp so the two subsets can never overlap or exceed the data set.
    num_test = min(int(total * test_ratio), total)
    num_val = min(int(total * val_ratio), total - num_test)

    print(f"✅ 共找到 {total} 组有效数据")
    print(f"✅ 将移动:val={num_val}, test={num_test}")

    # os.listdir order is arbitrary, so sort first to make a seeded shuffle
    # reproducible across runs and machines.
    pairs.sort()
    if seed is None:
        random.shuffle(pairs)
    else:
        random.Random(seed).shuffle(pairs)

    test_sample = pairs[:num_test]
    val_sample = pairs[num_test:num_test + num_val]
    # Whatever is left stays in train.

    moved_test = _move_pairs(test_sample, train_dir, test_dir, "test")
    moved_val = _move_pairs(val_sample, train_dir, val_dir, "val")
    failed = (num_test - moved_test) + (num_val - moved_val)

    # Report what actually happened, not what was planned.
    print("\n✅ 划分完成!")
    print(f" train 保留: {total - moved_val - moved_test}")
    print(f" val : {moved_val}")
    print(f" test : {moved_test}")
    if failed:
        print(f" ⚠️ 有 {failed} 组数据移动失败,仍保留在 train 中。")
    else:
        print(f" 所有 val 和 test 数据已从 train 移出。")


# ========================
# Usage example
# ========================
if __name__ == "__main__":
    train_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/train"
    val_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/val"
    test_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/test"

    split_train_val_test(train_dir, val_dir, test_dir, val_ratio=0.1, test_ratio=0.1)
# ========================