# 109 lines · 3.8 KiB · Python
import argparse
import os
import random
import shutil
from pathlib import Path

# Suffixes (compared lower-cased) of files treated as images / sampling units.
IMG_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'})


def _list_images(directory):
    """Return the image files directly inside *directory*, sorted by path.

    Sorting makes the later seeded sampling independent of the filesystem's
    directory-listing order, so a given seed selects the same files everywhere.
    """
    return sorted(
        f for f in directory.iterdir()
        if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
    )


def _num_to_move(total, ratio):
    """Return how many of *total* images to move: ``int(total * ratio)``, but at least 1."""
    return max(1, int(total * ratio))


def _move_with_label(img, dst_dir):
    """Move *img* into *dst_dir*, together with its same-stem ``.txt`` label if one exists."""
    shutil.move(str(img), str(dst_dir / img.name))
    txt_file = img.with_suffix('.txt')
    if txt_file.exists():
        shutil.move(str(txt_file), str(dst_dir / txt_file.name))


def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42):
    """Move a random ``ratio`` fraction of labelled images from train_dir to val_dir.

    The layout of ``train_dir`` is detected automatically:

    * classification layout: ``train_dir`` contains class sub-folders and no
      image files at its top level. Each class is sampled independently and
      mirrored under ``val_dir/<class name>``. Top-level non-image files
      (e.g. ``classes.txt``) are ignored instead of breaking detection.
    * flat (YOLO) layout: images with optional same-stem ``.txt`` labels
      sit directly inside ``train_dir``.

    At least one image is moved per class (classification) or overall (flat).
    Every moved image takes its same-stem ``.txt`` label with it when present.

    Args:
        train_dir (str | os.PathLike): training-set directory.
        val_dir (str | os.PathLike): validation-set directory; created if missing.
        ratio (float): fraction of images to move, in (0, 1]; 0.1 means 10%.
        seed (int): seed of a private RNG. The selection is reproducible for a
            given seed and file set, and the global ``random`` state is untouched.

    Raises:
        FileNotFoundError: if ``train_dir`` does not exist.
        ValueError: if ``ratio`` is not in (0, 1].
    """
    train_path = Path(train_dir)
    val_path = Path(val_dir)

    if not train_path.exists():
        raise FileNotFoundError(f"训练目录不存在: {train_path}")
    # Reject ratios that would move nothing-but-still-one file (<= 0) or make
    # random.sample fail half-way through a run (> 1).
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio 必须在 (0, 1] 范围内, 当前为: {ratio}")

    # Private RNG: reproducible without reseeding the process-wide `random` module.
    rng = random.Random(seed)

    # Sorted so class processing order (and therefore the RNG draw order)
    # does not depend on the filesystem.
    class_dirs = sorted(p for p in train_path.iterdir() if p.is_dir())
    top_level_images = _list_images(train_path)

    # Classification layout: class sub-folders and no loose images at the top.
    is_classification = bool(class_dirs) and not top_level_images

    if is_classification:
        print("📁 检测到分类结构(含类别子文件夹)")
        for class_dir in class_dirs:
            class_name = class_dir.name
            dst_class_dir = val_path / class_name
            # Created even for empty classes, so val keeps the same class folders as train.
            dst_class_dir.mkdir(parents=True, exist_ok=True)

            image_files = _list_images(class_dir)
            if not image_files:
                print(f" ⚠️ 类别 '{class_name}' 中无图像文件,跳过")
                continue

            val_images = rng.sample(image_files, _num_to_move(len(image_files), ratio))
            for img in val_images:
                _move_with_label(img, dst_class_dir)

            print(f" ✅ 类别 '{class_name}': {len(val_images)} 张图像已移至 val")

    else:
        print("📄 检测到平铺结构(YOLO格式:图像 + 同名 .txt 标签)")
        val_path.mkdir(parents=True, exist_ok=True)

        if not top_level_images:
            print("⚠️ 训练目录中未找到任何图像文件(支持格式: jpg, png 等)")
            return

        val_images = rng.sample(
            top_level_images, _num_to_move(len(top_level_images), ratio)
        )
        for img in val_images:
            _move_with_label(img, val_path)

        print(f"✅ 平铺结构: 已移动 {len(val_images)} 张图像及其标签到 {val_path}")

    print(f"\n🎉 分割完成!验证集已保存至: {val_path}")
# ======================
# Command-line entry point
# ======================
if __name__ == "__main__":
    # Original hard-coded run; kept as defaults so a bare invocation is unchanged.
    TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train"
    VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/val1"

    parser = argparse.ArgumentParser(
        description="从训练集随机抽取部分带标签图像移动到验证集",
    )
    parser.add_argument("--train-dir", default=TRAIN_DIR, help="训练集路径")
    parser.add_argument("--val-dir", default=VAL_DIR, help="验证集路径(会自动创建)")
    # '%%' because argparse %-formats help strings.
    parser.add_argument("--ratio", type=float, default=0.1, help="抽取比例,如 0.1 表示 10%%")
    parser.add_argument("--seed", type=int, default=42, help="随机种子,保证可复现")
    args = parser.parse_args()

    split_train_to_val(
        train_dir=args.train_dir,
        val_dir=args.val_dir,
        ratio=args.ratio,
        seed=args.seed,
    )