first commit
This commit is contained in:
108
class_xiantiao_pc/divid_val.py
Normal file
108
class_xiantiao_pc/divid_val.py
Normal file
@ -0,0 +1,108 @@
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42):
|
||||
"""
|
||||
从 train_dir 随机抽取 ratio 比例的数据到 val_dir。
|
||||
自动判断是分类结构(有子文件夹)还是平铺结构(无子文件夹)。
|
||||
|
||||
Args:
|
||||
train_dir (str): 训练集路径
|
||||
val_dir (str): 验证集路径(会自动创建)
|
||||
ratio (float): 抽取比例,如 0.1 表示 10%
|
||||
seed (int): 随机种子,保证可复现
|
||||
"""
|
||||
train_path = Path(train_dir)
|
||||
val_path = Path(val_dir)
|
||||
|
||||
if not train_path.exists():
|
||||
raise FileNotFoundError(f"训练目录不存在: {train_path}")
|
||||
|
||||
# 设置随机种子
|
||||
random.seed(seed)
|
||||
|
||||
# 获取所有一级子项
|
||||
items = [p for p in train_path.iterdir()]
|
||||
|
||||
# 判断是否为分类结构:所有子项都是目录
|
||||
is_classification = all(p.is_dir() for p in items) and len(items) > 0
|
||||
|
||||
if is_classification:
|
||||
print("📁 检测到分类结构(含类别子文件夹)")
|
||||
for class_dir in items:
|
||||
class_name = class_dir.name
|
||||
src_class_dir = train_path / class_name
|
||||
dst_class_dir = val_path / class_name
|
||||
dst_class_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取该类下所有文件(只取图像,但会连带移动同名标签)
|
||||
files = [f for f in src_class_dir.iterdir() if f.is_file()]
|
||||
if not files:
|
||||
print(f" ⚠️ 类别 '{class_name}' 为空,跳过")
|
||||
continue
|
||||
|
||||
# 随机抽取
|
||||
num_val = max(1, int(len(files) * ratio)) # 至少抽1个
|
||||
val_files = random.sample(files, num_val)
|
||||
|
||||
# 移动文件(包括可能的同名标签)
|
||||
for f in val_files:
|
||||
# 移动主文件
|
||||
shutil.move(str(f), str(dst_class_dir / f.name))
|
||||
# 尝试移动同名不同扩展名的标签(如 .txt)
|
||||
for ext in ['.txt', '.xml', '.json']:
|
||||
label_file = f.with_suffix(ext)
|
||||
if label_file.exists():
|
||||
shutil.move(str(label_file), str(dst_class_dir / label_file.name))
|
||||
|
||||
print(f" ✅ 类别 '{class_name}': {len(val_files)} / {len(files)} 已移至 val")
|
||||
|
||||
else:
|
||||
print("📄 检测到平铺结构(无类别子文件夹)")
|
||||
val_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取所有文件,按“主文件”分组(如 img.jpg 和 img.txt 视为一组)
|
||||
all_files = [f for f in train_path.iterdir() if f.is_file()]
|
||||
# 提取所有不带扩展名的 stem(去重)
|
||||
stems = set(f.stem for f in all_files)
|
||||
file_groups = []
|
||||
for stem in stems:
|
||||
group = [f for f in all_files if f.stem == stem]
|
||||
file_groups.append(group)
|
||||
|
||||
if not file_groups:
|
||||
print("⚠️ 训练目录为空")
|
||||
return
|
||||
|
||||
# 随机抽取组
|
||||
num_val = max(1, int(len(file_groups) * ratio))
|
||||
val_groups = random.sample(file_groups, num_val)
|
||||
|
||||
# 移动每组所有文件
|
||||
for group in val_groups:
|
||||
for f in group:
|
||||
shutil.move(str(f), str(val_path / f.name))
|
||||
|
||||
print(f"✅ 平铺结构: {len(val_groups)} 组 / {len(file_groups)} 组 已移至 val")
|
||||
|
||||
print(f"\n🎉 分割完成!验证集已保存至: {val_path}")
|
||||
|
||||
|
||||
# ======================
|
||||
# 使用示例
|
||||
# ======================
|
||||
if __name__ == "__main__":
|
||||
# 修改为你自己的路径
|
||||
#TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/train"
|
||||
#VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-new/19cc/val"
|
||||
TRAIN_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/train"
|
||||
VAL_DIR = "/home/hx/开发/ML_xiantiao/image/datasetr1/val"
|
||||
split_train_to_val(
|
||||
train_dir=TRAIN_DIR,
|
||||
val_dir=VAL_DIR,
|
||||
ratio=0.1, # 抽取 10%
|
||||
seed=25 # 随机种子
|
||||
)
|
||||
Reference in New Issue
Block a user