Files
zjsh_classification/data_load.py
2025-08-13 18:03:52 +08:00

25 lines
861 B
Python

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Image preprocessing pipeline.
# Per-channel mean and standard deviation of the ImageNet pre-training data.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        # Resize to the MobileNet input size.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Dataset paths: one directory per split.
train_dir, val_dir = 'dataset_root/train', 'dataset_root/val'
# Load both splits as ImageFolder datasets; each one applies the module-level
# `transform` pipeline to its images.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)
# Batch loaders. Both splits share the same batch size and worker count;
# only the training split is shuffled.
_loader_options = {'batch_size': 32, 'num_workers': 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
# Print the class -> index mapping, e.g. {'class0': 0, 'class1': 1}
print(train_dataset.class_to_idx)