95 lines
4.4 KiB
Python
95 lines
4.4 KiB
Python
import os
|
|
import random
|
|
import shutil
|
|
from tqdm import tqdm
|
|
|
|
def split_dataset(
|
|
images_dir: str, # 图像文件夹路径
|
|
labels_dir: str, # 标签文件夹路径
|
|
output_root: str, # 输出根目录
|
|
train_ratio: float = 0.7, # 训练集比例
|
|
val_ratio: float = 0.2, # 验证集比例
|
|
image_extensions: tuple = (".jpg", ".png", ".jpeg"), # 支持的图像文件格式
|
|
seed: int = 42, # 随机种子
|
|
):
|
|
"""
|
|
将数据集按照 7:2:1 的比例划分为训练集、验证集和测试集,并将图像和标签文件分别复制到相应的文件夹中。
|
|
|
|
参数:
|
|
images_dir (str): 图像文件夹路径。
|
|
labels_dir (str): 标签文件夹路径。
|
|
output_root (str): 输出根目录。
|
|
train_ratio (float): 训练集比例,默认为 0.7。
|
|
val_ratio (float): 验证集比例,默认为 0.2。
|
|
image_extensions (tuple): 支持的图像文件格式,默认为 (".jpg", ".png", ".jpeg")。
|
|
seed (int): 随机种子,默认为 42。
|
|
"""
|
|
# 设置随机种子
|
|
random.seed(seed)
|
|
|
|
# 创建输出文件夹
|
|
output_images_dir = os.path.join(output_root, "images")
|
|
output_labels_dir = os.path.join(output_root, "labels")
|
|
|
|
# 创建训练集、验证集和测试集的图像文件夹
|
|
for subset in ["train", "val", "test"]:
|
|
os.makedirs(os.path.join(output_images_dir, subset), exist_ok=True) # 图像文件夹
|
|
|
|
# 创建训练集和验证集的标签文件夹(指定路径)
|
|
os.makedirs(os.path.join(output_labels_dir, "train_original"), exist_ok=True) # 训练集标签文件夹
|
|
os.makedirs(os.path.join(output_labels_dir, "val_original"), exist_ok=True) # 验证集标签文件夹
|
|
os.makedirs(os.path.join(output_labels_dir, "test"), exist_ok=True) # 测试集标签文件夹
|
|
|
|
# 获取所有图像文件的文件名列表
|
|
image_files = [f for f in os.listdir(images_dir) if f.endswith(image_extensions)]
|
|
random.shuffle(image_files) # 随机打乱文件列表
|
|
|
|
# 计算训练集、验证集和测试集的大小
|
|
total_files = len(image_files)
|
|
train_size = int(total_files * train_ratio) # 训练集大小
|
|
val_size = int(total_files * val_ratio) # 验证集大小
|
|
test_size = total_files - train_size - val_size # 测试集大小
|
|
|
|
# 复制图像和标签文件到相应的子集文件夹中
|
|
for i, image_file in enumerate(tqdm(image_files, desc="分割数据集中")):
|
|
base_file_name = os.path.splitext(image_file)[0] # 获取文件名(不包括扩展名)
|
|
image_path = os.path.join(images_dir, image_file) # 图像文件路径
|
|
label_path = os.path.join(labels_dir, base_file_name + ".txt") # 标签文件路径
|
|
|
|
# 根据索引判断文件应复制到训练集、验证集还是测试集
|
|
if i < train_size:
|
|
# 复制到训练集
|
|
shutil.copy(image_path, os.path.join(output_images_dir, "train", image_file)) # 复制图像
|
|
shutil.copy(label_path, os.path.join(output_labels_dir, "train_original", base_file_name + ".txt")) # 复制标签
|
|
elif i < train_size + val_size:
|
|
# 复制到验证集
|
|
shutil.copy(image_path, os.path.join(output_images_dir, "val", image_file)) # 复制图像
|
|
shutil.copy(label_path, os.path.join(output_labels_dir, "val_original", base_file_name + ".txt")) # 复制标签
|
|
else:
|
|
# 复制到测试集
|
|
shutil.copy(image_path, os.path.join(output_images_dir, "test", image_file)) # 复制图像
|
|
shutil.copy(label_path, os.path.join(output_labels_dir, "test", base_file_name + ".txt")) # 复制标签
|
|
|
|
print(f"数据集分割完成!训练集: {train_size} 个样本,验证集: {val_size} 个样本,测试集: {test_size} 个样本。")
|
|
|
|
|
|
# 示例调用
|
|
if __name__ == "__main__":
|
|
# 数据集路径
|
|
images_dir = r"/home/hx/桌面/image/image" # 需要读取的所有图像文件夹路径
|
|
labels_dir = r"/home/hx/桌面/image/2" # 需要读取的所有图像与之对应的txt标签文件夹路径
|
|
|
|
# 输出路径
|
|
output_root = r"/home/hx/桌面/image" # 保存最终数据集的根目录
|
|
|
|
# 调用函数
|
|
split_dataset(
|
|
images_dir=images_dir,
|
|
labels_dir=labels_dir,
|
|
output_root=output_root,
|
|
train_ratio=0.7, # 训练集比例
|
|
val_ratio=0.2, # 验证集比例
|
|
image_extensions=(".jpg", ".png", ".jpeg"), # 支持的图像文件格式
|
|
seed=42, # 随机种子
|
|
)
|