"""Build ImageNet-normalised train/val DataLoaders from ImageFolder datasets.

Each split is a directory with one sub-directory per class, read with
``torchvision.datasets.ImageFolder``. Images are resized to the 224x224
MobileNet input size and normalised with ImageNet statistics.
"""
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Root directory that contains the ``train`` and ``val`` splits. It can be
# overridden with the CLASSDATA_ROOT environment variable; the default is the
# original machine-specific location, so existing runs are unaffected.
DATA_ROOT = os.environ.get(
    "CLASSDATA_ROOT",
    "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata",
)

# ImageNet per-channel mean/std, matching ImageNet-pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

BATCH_SIZE = 32
NUM_WORKERS = 4

# Image preprocessing, shared by both splits (no training-time augmentation).
transform = transforms.Compose([
    # Fixed MobileNet input size; note (224, 224) does not preserve aspect ratio.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Dataset paths.
train_dir = os.path.join(DATA_ROOT, "train")
val_dir = os.path.join(DATA_ROOT, "val")

# Load the datasets.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# ImageFolder derives label indices from each split's own class folders, so a
# folder missing from (or extra in) one split would silently shift labels.
# Fail loudly instead of training/evaluating against mismatched labels.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "train/val class folders differ: "
        f"train={train_dataset.class_to_idx}, val={val_dataset.class_to_idx}"
    )

# DataLoaders: shuffle only the training split.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

print(train_dataset.class_to_idx)  # e.g. {'class0': 0, 'class1': 1}