import os import shutil import random from pathlib import Path def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42): """ 从 train_dir 随机抽取 ratio 比例的数据到 val_dir。 自动判断是分类结构(有子文件夹)还是平铺结构(无子文件夹)。 Args: train_dir (str): 训练集路径 val_dir (str): 验证集路径(会自动创建) ratio (float): 抽取比例,如 0.1 表示 10% seed (int): 随机种子,保证可复现 """ train_path = Path(train_dir) val_path = Path(val_dir) if not train_path.exists(): raise FileNotFoundError(f"训练目录不存在: {train_path}") # 设置随机种子 random.seed(seed) # 获取所有一级子项 items = [p for p in train_path.iterdir()] # 判断是否为分类结构:所有子项都是目录 is_classification = all(p.is_dir() for p in items) and len(items) > 0 if is_classification: print("📁 检测到分类结构(含类别子文件夹)") for class_dir in items: class_name = class_dir.name src_class_dir = train_path / class_name dst_class_dir = val_path / class_name dst_class_dir.mkdir(parents=True, exist_ok=True) # 获取该类下所有文件(只取图像,但会连带移动同名标签) files = [f for f in src_class_dir.iterdir() if f.is_file()] if not files: print(f" ⚠️ 类别 '{class_name}' 为空,跳过") continue # 随机抽取 num_val = max(1, int(len(files) * ratio)) # 至少抽1个 val_files = random.sample(files, num_val) # 移动文件(包括可能的同名标签) for f in val_files: # 移动主文件 shutil.move(str(f), str(dst_class_dir / f.name)) # 尝试移动同名不同扩展名的标签(如 .txt) for ext in ['.txt', '.xml', '.json']: label_file = f.with_suffix(ext) if label_file.exists(): shutil.move(str(label_file), str(dst_class_dir / label_file.name)) print(f" ✅ 类别 '{class_name}': {len(val_files)} / {len(files)} 已移至 val") else: print("📄 检测到平铺结构(无类别子文件夹)") val_path.mkdir(parents=True, exist_ok=True) # 获取所有文件,按“主文件”分组(如 img.jpg 和 img.txt 视为一组) all_files = [f for f in train_path.iterdir() if f.is_file()] # 提取所有不带扩展名的 stem(去重) stems = set(f.stem for f in all_files) file_groups = [] for stem in stems: group = [f for f in all_files if f.stem == stem] file_groups.append(group) if not file_groups: print("⚠️ 训练目录为空") return # 随机抽取组 num_val = max(1, int(len(file_groups) * ratio)) val_groups = random.sample(file_groups, num_val) # 移动每组所有文件 for group in val_groups: for f in group: shutil.move(str(f), str(val_path / f.name)) print(f"✅ 平铺结构: {len(val_groups)} 组 / {len(file_groups)} 组 已移至 val") print(f"\n🎉 分割完成!验证集已保存至: {val_path}") # ====================== # 使用示例 # ====================== if __name__ == "__main__": # 修改为你自己的路径 #TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/train" #VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/val" TRAIN_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/train" VAL_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/val" split_train_to_val( train_dir=TRAIN_DIR, val_dir=VAL_DIR, ratio=0.1, # 抽取 10% seed=25 # 随机种子 )