"""Split a training dataset by randomly moving a fraction of it into a validation set."""

import os
import random
import shutil
from pathlib import Path
def _move_with_labels(src_file, dst_dir):
    """Move *src_file* into *dst_dir*, along with any same-stem label file.

    Label companions are files sharing the stem with a ``.txt``/``.xml``/``.json``
    suffix (e.g. ``img.jpg`` drags ``img.txt`` along).
    """
    shutil.move(str(src_file), str(dst_dir / src_file.name))
    for ext in ('.txt', '.xml', '.json'):
        label_file = src_file.with_suffix(ext)
        if label_file.exists():
            shutil.move(str(label_file), str(dst_dir / label_file.name))


def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42):
    """Randomly move a *ratio* fraction of the data in ``train_dir`` to ``val_dir``.

    Automatically detects whether ``train_dir`` uses a classification layout
    (every top-level entry is a class sub-directory) or a flat layout (files
    directly inside). In the flat layout, files sharing a stem (e.g.
    ``img.jpg`` and ``img.txt``) are moved together as one group.

    Args:
        train_dir (str): Path of the training set.
        val_dir (str): Path of the validation set (created automatically).
        ratio (float): Fraction to extract, e.g. 0.1 for 10%. At least one
            item per non-empty class/group is moved.
        seed (int): Random seed; splits are reproducible for a given tree.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
    """
    train_path = Path(train_dir)
    val_path = Path(val_dir)

    if not train_path.exists():
        raise FileNotFoundError(f"训练目录不存在: {train_path}")

    # Private RNG: same sequence as random.seed(seed) would give, but the
    # global `random` module state is left untouched for other callers.
    rng = random.Random(seed)

    # Sort the listing: iterdir() order is OS-dependent, so without sorting
    # the "reproducible" seed would not actually reproduce the same split.
    items = sorted(train_path.iterdir())

    # Classification layout: non-empty and every top-level entry is a directory.
    is_classification = len(items) > 0 and all(p.is_dir() for p in items)

    if is_classification:
        print("📁 检测到分类结构(含类别子文件夹)")
        for class_dir in items:
            class_name = class_dir.name
            dst_class_dir = val_path / class_name
            dst_class_dir.mkdir(parents=True, exist_ok=True)

            # All files of this class (labels included); sorted for determinism.
            files = sorted(f for f in class_dir.iterdir() if f.is_file())
            if not files:
                print(f" ⚠️ 类别 '{class_name}' 为空,跳过")
                continue

            num_val = max(1, int(len(files) * ratio))  # move at least one
            val_files = rng.sample(files, num_val)

            for f in val_files:
                _move_with_labels(f, dst_class_dir)

            print(f" ✅ 类别 '{class_name}': {len(val_files)} / {len(files)} 已移至 val")

    else:
        print("📄 检测到平铺结构(无类别子文件夹)")
        val_path.mkdir(parents=True, exist_ok=True)

        # Group files by stem in a single pass (O(n) instead of rescanning
        # the whole listing once per stem), so img.jpg + img.txt stay paired.
        groups_by_stem = {}
        for f in items:
            if f.is_file():
                groups_by_stem.setdefault(f.stem, []).append(f)
        # Deterministic group order — a plain set of stems has random order.
        file_groups = [groups_by_stem[stem] for stem in sorted(groups_by_stem)]

        if not file_groups:
            print("⚠️ 训练目录为空")
            return

        num_val = max(1, int(len(file_groups) * ratio))  # move at least one group
        val_groups = rng.sample(file_groups, num_val)

        # Move every file of each selected group.
        for group in val_groups:
            for f in group:
                shutil.move(str(f), str(val_path / f.name))

        print(f"✅ 平铺结构: {len(val_groups)} 组 / {len(file_groups)} 组 已移至 val")

    print(f"\n🎉 分割完成!验证集已保存至: {val_path}")
# ======================
# Usage example
# ======================
if __name__ == "__main__":
    # Edit these paths for your own dataset.
    TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/charge/train"
    VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/charge/val"

    split_train_to_val(
        train_dir=TRAIN_DIR,
        val_dir=VAL_DIR,
        ratio=0.1,  # extract 10%
        seed=58,    # random seed for a reproducible split
    )