import os import random import shutil from tqdm import tqdm def split_dataset( images_dir: str, # 图像文件夹路径 labels_dir: str, # 标签文件夹路径 output_root: str, # 输出根目录 train_ratio: float = 0.7, # 训练集比例 val_ratio: float = 0.2, # 验证集比例 image_extensions: tuple = (".jpg", ".png", ".jpeg"), # 支持的图像文件格式 seed: int = 42, # 随机种子 ): """ 将数据集按照 7:2:1 的比例划分为训练集、验证集和测试集,并将图像和标签文件分别复制到相应的文件夹中。 参数: images_dir (str): 图像文件夹路径。 labels_dir (str): 标签文件夹路径。 output_root (str): 输出根目录。 train_ratio (float): 训练集比例,默认为 0.7。 val_ratio (float): 验证集比例,默认为 0.2。 image_extensions (tuple): 支持的图像文件格式,默认为 (".jpg", ".png", ".jpeg")。 seed (int): 随机种子,默认为 42。 """ # 设置随机种子 random.seed(seed) # 创建输出文件夹 output_images_dir = os.path.join(output_root, "images") output_labels_dir = os.path.join(output_root, "labels") # 创建训练集、验证集和测试集的图像文件夹 for subset in ["train", "val", "test"]: os.makedirs(os.path.join(output_images_dir, subset), exist_ok=True) # 图像文件夹 # 创建训练集和验证集的标签文件夹(指定路径) os.makedirs(os.path.join(output_labels_dir, "train_original"), exist_ok=True) # 训练集标签文件夹 os.makedirs(os.path.join(output_labels_dir, "val_original"), exist_ok=True) # 验证集标签文件夹 os.makedirs(os.path.join(output_labels_dir, "test"), exist_ok=True) # 测试集标签文件夹 # 获取所有图像文件的文件名列表 image_files = [f for f in os.listdir(images_dir) if f.endswith(image_extensions)] random.shuffle(image_files) # 随机打乱文件列表 # 计算训练集、验证集和测试集的大小 total_files = len(image_files) train_size = int(total_files * train_ratio) # 训练集大小 val_size = int(total_files * val_ratio) # 验证集大小 test_size = total_files - train_size - val_size # 测试集大小 # 复制图像和标签文件到相应的子集文件夹中 for i, image_file in enumerate(tqdm(image_files, desc="分割数据集中")): base_file_name = os.path.splitext(image_file)[0] # 获取文件名(不包括扩展名) image_path = os.path.join(images_dir, image_file) # 图像文件路径 label_path = os.path.join(labels_dir, base_file_name + ".txt") # 标签文件路径 # 根据索引判断文件应复制到训练集、验证集还是测试集 if i < train_size: # 复制到训练集 shutil.copy(image_path, os.path.join(output_images_dir, "train", image_file)) # 复制图像 shutil.copy(label_path, os.path.join(output_labels_dir, "train_original", base_file_name + ".txt")) # 复制标签 elif i < train_size + val_size: # 复制到验证集 shutil.copy(image_path, os.path.join(output_images_dir, "val", image_file)) # 复制图像 shutil.copy(label_path, os.path.join(output_labels_dir, "val_original", base_file_name + ".txt")) # 复制标签 else: # 复制到测试集 shutil.copy(image_path, os.path.join(output_images_dir, "test", image_file)) # 复制图像 shutil.copy(label_path, os.path.join(output_labels_dir, "test", base_file_name + ".txt")) # 复制标签 print(f"数据集分割完成!训练集: {train_size} 个样本,验证集: {val_size} 个样本,测试集: {test_size} 个样本。") # 示例调用 if __name__ == "__main__": # 数据集路径 images_dir = r"/home/hx/桌面/image/image" # 需要读取的所有图像文件夹路径 labels_dir = r"/home/hx/桌面/image/2" # 需要读取的所有图像与之对应的txt标签文件夹路径 # 输出路径 output_root = r"/home/hx/桌面/image" # 保存最终数据集的根目录 # 调用函数 split_dataset( images_dir=images_dir, labels_dir=labels_dir, output_root=output_root, train_ratio=0.7, # 训练集比例 val_ratio=0.2, # 验证集比例 image_extensions=(".jpg", ".png", ".jpeg"), # 支持的图像文件格式 seed=42, # 随机种子 )