109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
|
|
import os
|
|||
|
|
import shutil
|
|||
|
|
import random
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _collect_images(directory, extensions):
    """Return the image files directly inside *directory*, sorted by path.

    ``Path.iterdir()`` yields entries in arbitrary, filesystem-dependent
    order; sorting makes seeded sampling select the same files on every run.
    """
    return sorted(
        entry for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in extensions
    )


def _move_sample(image_files, dst_dir, ratio, rng):
    """Move a random share of *image_files* (plus same-stem .txt labels) to *dst_dir*.

    At least one image is always moved, so small folders still contribute
    to the validation set. Returns the list of images that were moved.
    """
    num_val = max(1, int(len(image_files) * ratio))
    val_images = rng.sample(image_files, num_val)

    for img in val_images:
        shutil.move(str(img), str(dst_dir / img.name))
        # A label file, when present, shares the image's stem (e.g. YOLO .txt).
        txt_file = img.with_suffix('.txt')
        if txt_file.exists():
            shutil.move(str(txt_file), str(dst_dir / txt_file.name))

    return val_images


def split_train_to_val(train_dir, val_dir, ratio=0.1, seed=42):
    """Move a random ``ratio`` share of the images in ``train_dir`` to ``val_dir``.

    The layout is detected automatically:

    * classification layout - every first-level entry of ``train_dir`` is a
      directory; each class folder is sampled independently and mirrored
      under ``val_dir``;
    * flat layout - anything else; images directly inside ``train_dir`` are
      sampled.

    A ``.txt`` file with the same stem as a selected image is moved along
    with it. Files are moved, not copied. At least one image is moved from
    every folder that contains images.

    Args:
        train_dir (str): Path of the training set.
        val_dir (str): Path of the validation set (created if missing).
        ratio (float): Share of images to move, e.g. 0.1 for 10%. Must be
            in the interval (0, 1].
        seed (int): Seed of the private random generator. Candidates are
            sorted before sampling, so the same seed picks the same files
            regardless of directory listing order.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
        ValueError: If ``ratio`` is outside (0, 1].
    """
    train_path = Path(train_dir)
    val_path = Path(val_dir)

    if not train_path.exists():
        raise FileNotFoundError(f"训练目录不存在: {train_path}")

    # Reject bad ratios before touching the filesystem so a failure can
    # never leave the dataset half-moved.
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in the interval (0, 1], got {ratio!r}")

    # A private generator keeps the split reproducible without reseeding the
    # process-wide ``random`` state as a side effect.
    rng = random.Random(seed)

    # First-level entries, sorted so classes are always processed (and thus
    # consume random numbers) in the same order.
    items = sorted(train_path.iterdir())

    # Classification layout: at least one entry, and every entry is a directory.
    is_classification = bool(items) and all(p.is_dir() for p in items)

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'}

    if is_classification:
        print("📁 检测到分类结构(含类别子文件夹)")
        for class_dir in items:
            class_name = class_dir.name
            dst_class_dir = val_path / class_name
            dst_class_dir.mkdir(parents=True, exist_ok=True)

            image_files = _collect_images(class_dir, IMG_EXTENSIONS)
            if not image_files:
                print(f" ⚠️ 类别 '{class_name}' 中无图像文件,跳过")
                continue

            val_images = _move_sample(image_files, dst_class_dir, ratio, rng)
            print(f" ✅ 类别 '{class_name}': {len(val_images)} 张图像已移至 val")
    else:
        print("📄 检测到平铺结构(YOLO格式:图像 + 同名 .txt 标签)")
        val_path.mkdir(parents=True, exist_ok=True)

        # Images are the sampling unit; labels simply follow their image.
        image_files = _collect_images(train_path, IMG_EXTENSIONS)
        if not image_files:
            print("⚠️ 训练目录中未找到任何图像文件(支持格式: jpg, png 等)")
            return

        val_images = _move_sample(image_files, val_path, ratio, rng)
        print(f"✅ 平铺结构: 已移动 {len(val_images)} 张图像及其标签到 {val_path}")

    print(f"\n🎉 分割完成!验证集已保存至: {val_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ----------------------
# Usage example
# ----------------------
if __name__ == "__main__":
    # Source (training) folder and the validation folder that receives the
    # sampled files; both paths are specific to the author's machine.
    TRAIN_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/train"
    VAL_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/point2/val1"

    # Move 10% of the images (with their labels) using a fixed seed.
    split_train_to_val(TRAIN_DIR, VAL_DIR, ratio=0.1, seed=42)
|