加一个三分类
This commit is contained in:
24
class2/data_load.py
Normal file
24
class2/data_load.py
Normal file
@ -0,0 +1,24 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# 图像预处理
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 调整为 MobileNet 输入尺寸
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]) # ImageNet 预训练均值方差
|
||||
])
|
||||
|
||||
# 数据集路径
|
||||
train_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/train'
|
||||
val_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/val'
|
||||
|
||||
# 加载数据集
|
||||
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
|
||||
val_dataset = datasets.ImageFolder(val_dir, transform=transform)
|
||||
|
||||
# DataLoader
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
|
||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
|
||||
|
||||
print(train_dataset.class_to_idx) # {'class0': 0, 'class1': 1}
|
||||
Reference in New Issue
Block a user