Files
xiantiao_CV/class_xiantiao_pc/divid_val.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

108 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import random
from pathlib import Path
def _group_files_by_stem(directory):
    """Group the files directly inside *directory* by their stem.

    Files that share a stem (for example ``img.jpg`` and ``img.txt``) are
    treated as one sample. Groups, and the files inside each group, are
    sorted by name so that seeded sampling does not depend on the order in
    which the filesystem lists entries.
    """
    by_stem = {}
    for entry in directory.iterdir():
        if entry.is_file():
            by_stem.setdefault(entry.stem, []).append(entry)
    return [sorted(by_stem[stem]) for stem in sorted(by_stem)]


def _move_random_groups(groups, dst_dir, ratio, rng):
    """Move a random ``ratio`` share of *groups* into *dst_dir* and return them."""
    # At least one group is always taken; the upper clamp keeps
    # ``rng.sample`` from raising when ``ratio`` exceeds 1.
    count = min(len(groups), max(1, int(len(groups) * ratio)))
    chosen = rng.sample(groups, count)
    for group in chosen:
        for member in group:
            shutil.move(str(member), str(dst_dir / member.name))
    return chosen


def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42):
    """Move a random share of the training data into a validation directory.

    The layout of ``train_dir`` is detected automatically:

    * classification layout - every first-level entry is a directory; each
      class directory is split independently and mirrored under ``val_dir``;
    * flat layout - anything else; the files directly inside ``train_dir``
      are split as a single pool.

    In both layouts files sharing a stem (``img.jpg`` + ``img.txt``) form one
    sample and are always moved together. Files are moved, not copied.

    Args:
        train_dir (str): Path of the training set.
        val_dir (str): Path of the validation set (created if missing).
        ratio (float): Share of samples to move, e.g. 0.1 for 10%. At least
            one sample is moved from every non-empty pool, and never more
            than the pool contains.
        seed (int): Seed for a private random generator, so the same
            directory contents always yield the same split. The global
            ``random`` state is left untouched.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
    """
    train_path = Path(train_dir)
    val_path = Path(val_dir)
    if not train_path.exists():
        raise FileNotFoundError(f"训练目录不存在: {train_path}")

    # A dedicated generator keeps the split reproducible without reseeding
    # the module-level RNG used by the rest of the program.
    rng = random.Random(seed)

    # Sorted so class directories are visited in a stable order.
    items = sorted(train_path.iterdir())
    is_classification = bool(items) and all(p.is_dir() for p in items)

    if is_classification:
        print("📁 检测到分类结构(含类别子文件夹)")
        for class_dir in items:
            class_name = class_dir.name
            dst_class_dir = val_path / class_name
            dst_class_dir.mkdir(parents=True, exist_ok=True)

            # Sample whole stem groups, so a label file is never picked
            # separately from (and moved twice with) its image.
            groups = _group_files_by_stem(class_dir)
            if not groups:
                print(f" ⚠️ 类别 '{class_name}' 为空,跳过")
                continue

            moved = _move_random_groups(groups, dst_class_dir, ratio, rng)
            print(f" ✅ 类别 '{class_name}': {len(moved)} / {len(groups)} 已移至 val")
    else:
        print("📄 检测到平铺结构(无类别子文件夹)")
        val_path.mkdir(parents=True, exist_ok=True)

        file_groups = _group_files_by_stem(train_path)
        if not file_groups:
            print("⚠️ 训练目录为空")
            return

        val_groups = _move_random_groups(file_groups, val_path, ratio, rng)
        print(f"✅ 平铺结构: {len(val_groups)} 组 / {len(file_groups)} 组 已移至 val")

    print(f"\n🎉 分割完成!验证集已保存至: {val_path}")
# ======================
# Usage example
# ======================
if __name__ == "__main__":
    # Edit the two paths below so they point at your own dataset.
    # Alternative dataset location, kept for reference:
    # TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/train"
    # VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/val"
    TRAIN_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/train"
    VAL_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/val"

    # Collect the call arguments in one place: a 10% share, random seed 25.
    split_options = {
        "train_dir": TRAIN_DIR,
        "val_dir": VAL_DIR,
        "ratio": 0.1,
        "seed": 25,
    }
    split_train_to_val(**split_options)