Files
zjsh_yolov11/yolo11_obb/divid_dataset.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

84 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import random
def _collect_pairs(train_dir, img_extensions):
    """Return sorted (image, label) filename pairs found directly in *train_dir*.

    A pair is a regular image file whose stem has a matching ``<stem>.txt``
    regular file. Images without a label are reported and skipped. When several
    images share a stem (e.g. ``a.jpg`` and ``a.png``) only the first one in
    sorted order is paired, so a single label file is never claimed twice.
    """
    entries = sorted(os.listdir(train_dir))
    # Only regular files count; a directory named "x.jpg" is not an image.
    # A set also makes the per-image label lookup O(1) instead of O(n).
    files = {f for f in entries if os.path.isfile(os.path.join(train_dir, f))}
    pairs = []
    claimed_labels = set()
    for img in entries:
        if img not in files or os.path.splitext(img.lower())[1] not in img_extensions:
            continue
        txt = os.path.splitext(img)[0] + '.txt'
        if txt not in files:
            print(f"⚠️ 忽略 {img}:缺少对应标签文件")
        elif txt in claimed_labels:
            print(f"⚠️ 忽略 {img}:标签文件 {txt} 已被其他图片使用")
        else:
            claimed_labels.add(txt)
            pairs.append((img, txt))
    return pairs


def _move_pair(src_dir, dst_dir, img, txt):
    """Move image *img* and label *txt* from *src_dir* to *dst_dir* as a unit.

    Raises OSError (FileExistsError when a destination file already exists).
    If the label move fails, the image is moved back first, so the pair is not
    left split across two directories.
    """
    img_src, img_dst = os.path.join(src_dir, img), os.path.join(dst_dir, img)
    txt_src, txt_dst = os.path.join(src_dir, txt), os.path.join(dst_dir, txt)
    for dst in (img_dst, txt_dst):
        # shutil.move may silently overwrite an existing destination file.
        if os.path.exists(dst):
            raise FileExistsError(f"目标已存在: {dst}")
    shutil.move(img_src, img_dst)
    try:
        shutil.move(txt_src, txt_dst)
    except OSError:
        shutil.move(img_dst, img_src)  # roll back the image move
        raise


def _move_split(pairs, src_dir, dst_dir, split_name):
    """Move every pair into *dst_dir*, reporting failures; return the number moved."""
    moved = 0
    for img, txt in pairs:
        try:
            _move_pair(src_dir, dst_dir, img, txt)
        except OSError as e:
            print(f"❌ 移动到 {split_name} 失败: {img}, {e}")
        else:
            moved += 1
    return moved


def split_train_val_test(train_dir, val_dir, test_dir, val_ratio=0.1, test_ratio=0.1, seed=None):
    """Split the labelled images in *train_dir* into train/val/test, in place.

    Each image (jpg/jpeg/png/bmp/tiff/webp, extension case-insensitive) that
    has a same-stem ``.txt`` label forms a pair. A random ``test_ratio`` share
    of the pairs is moved to *test_dir*, and a disjoint ``val_ratio`` share is
    moved to *val_dir*. The remaining pairs stay in *train_dir*. Split sizes
    are truncated toward zero.

    Args:
        train_dir: directory holding images and labels side by side.
        val_dir: destination for validation pairs; created if missing.
        test_dir: destination for test pairs; created if missing.
        val_ratio: fraction of pairs for validation, in [0, 1].
        test_ratio: fraction of pairs for test, in [0, 1].
        seed: optional seed for a reproducible split. ``None`` shuffles with
            the global ``random`` state, as before.

    Raises:
        ValueError: a ratio is outside [0, 1] or the two ratios sum to more than 1.
        FileNotFoundError: *train_dir* is not an existing directory.
    """
    for label, ratio in (("val_ratio", val_ratio), ("test_ratio", test_ratio)):
        if not 0 <= ratio <= 1:
            raise ValueError(f"{label} 必须在 [0, 1] 之间: {ratio}")
    # Small tolerance so float sums like 0.7 + 0.3 are not rejected.
    if val_ratio + test_ratio > 1 + 1e-9:
        raise ValueError(f"val_ratio + test_ratio 不能大于 1: {val_ratio + test_ratio}")

    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"训练目录不存在: {train_dir}")
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    pairs = _collect_pairs(train_dir, img_extensions)
    if not pairs:
        print("❌ 未找到有效数据对")
        return

    total = len(pairs)
    num_test = int(total * test_ratio)
    num_val = int(total * val_ratio)
    print(f"✅ 共找到 {total} 组有效数据")
    print(f"✅ 将移动val={num_val}, test={num_test}")

    # Pairs arrive sorted, so a fixed seed yields the same split every run.
    # A private Random instance leaves the global random state untouched.
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(pairs)
    test_sample = pairs[:num_test]
    val_sample = pairs[num_test:num_test + num_val]
    # Everything after the two samples stays in train.

    moved_test = _move_split(test_sample, train_dir, test_dir, "test")
    moved_val = _move_split(val_sample, train_dir, val_dir, "val")

    # Report what actually moved, not what was planned.
    print("\n✅ 划分完成!")
    print(f" train 保留: {total - moved_val - moved_test}")
    print(f" val : {moved_val}")
    print(f" test : {moved_test}")
    failed = (num_val + num_test) - (moved_val + moved_test)
    if failed == 0:
        print(" 所有 val 和 test 数据已从 train 移出。")
    else:
        print(f" ⚠️ 有 {failed} 组数据移动失败,请检查上方错误信息。")
# ========================
# Example usage
# ========================
if __name__ == "__main__":
    # Hard-coded absolute paths to the obb3 dataset on the author's machine;
    # edit them before running elsewhere. The split moves files out of
    # train_dir in place, so back up the dataset first.
    train_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/train"
    val_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/val"
    test_dir = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/obb3/test"
    split_train_val_test(
        train_dir=train_dir,
        val_dir=val_dir,
        test_dir=test_dir,
        val_ratio=0.1,
        test_ratio=0.1,
    )
# ========================