Files
zjsh_yolov11/yolo11_obb/d_data.py

95 lines
4.5 KiB
Python
Raw Permalink Normal View History

2025-08-18 15:13:54 +08:00
import os
import random
import shutil
from tqdm import tqdm
def split_dataset(
images_dir: str, # 图像文件夹路径
labels_dir: str, # 标签文件夹路径
output_root: str, # 输出根目录
train_ratio: float = 0.7, # 训练集比例
val_ratio: float = 0.2, # 验证集比例
image_extensions: tuple = (".jpg", ".png", ".jpeg"), # 支持的图像文件格式
seed: int = 42, # 随机种子
):
"""
将数据集按照 7:2:1 的比例划分为训练集验证集和测试集并将图像和标签文件分别复制到相应的文件夹中
参数:
images_dir (str): 图像文件夹路径
labels_dir (str): 标签文件夹路径
output_root (str): 输出根目录
train_ratio (float): 训练集比例默认为 0.7
val_ratio (float): 验证集比例默认为 0.2
image_extensions (tuple): 支持的图像文件格式默认为 (".jpg", ".png", ".jpeg")
seed (int): 随机种子默认为 42
"""
# 设置随机种子
random.seed(seed)
# 创建输出文件夹
output_images_dir = os.path.join(output_root, "images")
output_labels_dir = os.path.join(output_root, "labels")
# 创建训练集、验证集和测试集的图像文件夹
for subset in ["train", "val", "test"]:
os.makedirs(os.path.join(output_images_dir, subset), exist_ok=True) # 图像文件夹
# 创建训练集和验证集的标签文件夹(指定路径)
os.makedirs(os.path.join(output_labels_dir, "train_original"), exist_ok=True) # 训练集标签文件夹
os.makedirs(os.path.join(output_labels_dir, "val_original"), exist_ok=True) # 验证集标签文件夹
os.makedirs(os.path.join(output_labels_dir, "test"), exist_ok=True) # 测试集标签文件夹
# 获取所有图像文件的文件名列表
image_files = [f for f in os.listdir(images_dir) if f.endswith(image_extensions)]
random.shuffle(image_files) # 随机打乱文件列表
# 计算训练集、验证集和测试集的大小
total_files = len(image_files)
train_size = int(total_files * train_ratio) # 训练集大小
val_size = int(total_files * val_ratio) # 验证集大小
test_size = total_files - train_size - val_size # 测试集大小
# 复制图像和标签文件到相应的子集文件夹中
for i, image_file in enumerate(tqdm(image_files, desc="分割数据集中")):
base_file_name = os.path.splitext(image_file)[0] # 获取文件名(不包括扩展名)
image_path = os.path.join(images_dir, image_file) # 图像文件路径
label_path = os.path.join(labels_dir, base_file_name + ".txt") # 标签文件路径
# 根据索引判断文件应复制到训练集、验证集还是测试集
if i < train_size:
# 复制到训练集
shutil.copy(image_path, os.path.join(output_images_dir, "train", image_file)) # 复制图像
shutil.copy(label_path, os.path.join(output_labels_dir, "train_original", base_file_name + ".txt")) # 复制标签
elif i < train_size + val_size:
# 复制到验证集
shutil.copy(image_path, os.path.join(output_images_dir, "val", image_file)) # 复制图像
shutil.copy(label_path, os.path.join(output_labels_dir, "val_original", base_file_name + ".txt")) # 复制标签
else:
# 复制到测试集
shutil.copy(image_path, os.path.join(output_images_dir, "test", image_file)) # 复制图像
shutil.copy(label_path, os.path.join(output_labels_dir, "test", base_file_name + ".txt")) # 复制标签
print(f"数据集分割完成!训练集: {train_size} 个样本,验证集: {val_size} 个样本,测试集: {test_size} 个样本。")
# 示例调用
if __name__ == "__main__":
# 数据集路径
images_dir = r"/home/hx/桌面/image/image" # 需要读取的所有图像文件夹路径
2025-09-01 14:14:18 +08:00
labels_dir = r"/home/hx/桌面/image/image/2" # 需要读取的所有图像与之对应的txt标签文件夹路径
2025-08-18 15:13:54 +08:00
# 输出路径
output_root = r"/home/hx/桌面/image" # 保存最终数据集的根目录
# 调用函数
split_dataset(
images_dir=images_dir,
labels_dir=labels_dir,
output_root=output_root,
train_ratio=0.7, # 训练集比例
val_ratio=0.2, # 验证集比例
image_extensions=(".jpg", ".png", ".jpeg"), # 支持的图像文件格式
seed=42, # 随机种子
)