# (Original listing header: 25 lines, 861 B, Python.)
"""Build ImageFolder datasets and DataLoaders for MobileNet-style fine-tuning.

``dataset_root/train`` and ``dataset_root/val`` are each read with
``torchvision.datasets.ImageFolder``, i.e. one sub-directory per class.
Images are resized to 224x224 and normalized with the ImageNet mean/std so
they match ImageNet-pretrained weights.
"""
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Image preprocessing, shared by the training and validation sets.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the MobileNet input size
    transforms.ToTensor(),
    # ImageNet mean/std, matching the statistics of the pretrained weights.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Dataset paths: each directory must contain one sub-folder per class.
train_dir = 'dataset_root/train'
val_dir = 'dataset_root/val'

# Load the datasets. ImageFolder builds class_to_idx from the sorted
# sub-folder names of *each* directory independently.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# If val/ is missing a class folder (or has an extra one), the same label
# index would name different classes in the two splits and validation
# metrics would be silently wrong -- fail fast instead.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"class folders differ between {train_dir!r} and {val_dir!r}: "
        f"{train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
    )

# DataLoaders: shuffle only the training set; keep validation order fixed.
# NOTE(review): with num_workers > 0, platforms using the 'spawn' start
# method (e.g. Windows, macOS) re-import the main module in every worker.
# If this file is run directly and later iterates the loaders, put that
# code under `if __name__ == "__main__":`.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(train_dataset.class_to_idx)  # e.g. {'class0': 0, 'class1': 1}