import os import shutil import random from pathlib import Path def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42): """ 从 train_dir 随机抽取 ratio 比例的 **带标签图像** 到 val_dir。 自动判断是分类结构(有子文件夹)还是平铺结构(无子文件夹)。 Args: train_dir (str): 训练集路径 val_dir (str): 验证集路径(会自动创建) ratio (float): 抽取比例,如 0.1 表示 10% seed (int): 随机种子,保证可复现 """ train_path = Path(train_dir) val_path = Path(val_dir) if not train_path.exists(): raise FileNotFoundError(f"训练目录不存在: {train_path}") # 设置随机种子 random.seed(seed) # 获取所有一级子项 items = [p for p in train_path.iterdir()] # 判断是否为分类结构:所有子项都是目录且非空 is_classification = all(p.is_dir() for p in items) and len(items) > 0 # 定义图像扩展名 IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'} if is_classification: print("📁 检测到分类结构(含类别子文件夹)") for class_dir in items: class_name = class_dir.name src_class_dir = train_path / class_name dst_class_dir = val_path / class_name dst_class_dir.mkdir(parents=True, exist_ok=True) # 只找图像文件 image_files = [ f for f in src_class_dir.iterdir() if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS ] if not image_files: print(f" ⚠️ 类别 '{class_name}' 中无图像文件,跳过") continue num_val = max(1, int(len(image_files) * ratio)) val_images = random.sample(image_files, num_val) for img in val_images: # 移动图像 shutil.move(str(img), str(dst_class_dir / img.name)) # 移动同名 .txt 标签 txt_file = img.with_suffix('.txt') if txt_file.exists(): shutil.move(str(txt_file), str(dst_class_dir / txt_file.name)) print(f" ✅ 类别 '{class_name}': {len(val_images)} 张图像已移至 val") else: print("📄 检测到平铺结构(YOLO格式:图像 + 同名 .txt 标签)") val_path.mkdir(parents=True, exist_ok=True) # 只收集图像文件(作为采样单元) image_files = [ f for f in train_path.iterdir() if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS ] if not image_files: print("⚠️ 训练目录中未找到任何图像文件(支持格式: jpg, png 等)") return # 随机抽取图像 num_val = max(1, int(len(image_files) * ratio)) val_images = random.sample(image_files, num_val) # 移动选中的图像及其标签 for img in val_images: # 移动图像 shutil.move(str(img), str(val_path / img.name)) # 移动同名 .txt txt_file = img.with_suffix('.txt') if txt_file.exists(): shutil.move(str(txt_file), str(val_path / txt_file.name)) print(f"✅ 平铺结构: 已移动 {len(val_images)} 张图像及其标签到 {val_path}") print(f"\n🎉 分割完成!验证集已保存至: {val_path}") # ====================== # 使用示例 # ====================== if __name__ == "__main__": TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train" VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/val1" split_train_to_val( train_dir=TRAIN_DIR, val_dir=VAL_DIR, ratio=0.1, seed=42 )