diff --git a/.idea/class.iml b/.idea/class.iml index d0876a7..539f9cf 100644 --- a/.idea/class.iml +++ b/.idea/class.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 9de2865..adde6f5 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 6c0b863..288b36b 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -2,5 +2,6 @@ + \ No newline at end of file diff --git a/class2/checkpoints/mobilenetv3_best.pth b/class2/checkpoints/mobilenetv3_best.pth new file mode 100644 index 0000000..247e4e9 Binary files /dev/null and b/class2/checkpoints/mobilenetv3_best.pth differ diff --git a/class2/checkpoints/mobilenetv3_last_best.pth b/class2/checkpoints/mobilenetv3_last_best.pth new file mode 100644 index 0000000..c77c422 Binary files /dev/null and b/class2/checkpoints/mobilenetv3_last_best.pth differ diff --git a/classify.py b/class2/classify.py similarity index 92% rename from classify.py rename to class2/classify.py index 83d6bbd..b80a20f 100644 --- a/classify.py +++ b/class2/classify.py @@ -1,7 +1,7 @@ import torch from torchvision import transforms from PIL import Image -from mobilenetv3 import MobileNetV3_Large # 你的模型文件名 +from class2.test.mobilenetv3 import MobileNetV3_Large # 你的模型文件名 # ------------------------------- # 1. 定义推理设备 @@ -38,6 +38,6 @@ def predict(image_path): # ------------------------------- # 4. 测试推理 # ------------------------------- -image_path = "2.png" +image_path = "test/2.png" cls, prob = predict(image_path) print(f"Predicted class: {cls}, Probabilities: {prob}") diff --git a/copy_jpg_images.py b/class2/copy_jpg_images.py similarity index 100% rename from copy_jpg_images.py rename to class2/copy_jpg_images.py diff --git a/data_load.py b/class2/data_load.py similarity index 82% rename from data_load.py rename to class2/data_load.py index 6663191..bd95213 100644 --- a/data_load.py +++ b/class2/data_load.py @@ -10,8 +10,8 @@ transform = transforms.Compose([ ]) # 数据集路径 -train_dir = 'dataset_root/train' -val_dir = 'dataset_root/val' +train_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/train' +val_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/val' # 加载数据集 train_dataset = datasets.ImageFolder(train_dir, transform=transform) diff --git a/1.png b/class2/test/1.png similarity index 100% rename from 1.png rename to class2/test/1.png diff --git a/2.png b/class2/test/2.png similarity index 100% rename from 2.png rename to class2/test/2.png diff --git a/__pycache__/mobilenetv3.cpython-39.pyc b/class2/test/__pycache__/mobilenetv3.cpython-39.pyc similarity index 97% rename from __pycache__/mobilenetv3.cpython-39.pyc rename to class2/test/__pycache__/mobilenetv3.cpython-39.pyc index 515e2fd..0601b40 100644 Binary files a/__pycache__/mobilenetv3.cpython-39.pyc and b/class2/test/__pycache__/mobilenetv3.cpython-39.pyc differ diff --git a/mobilenetv3.py b/class2/test/mobilenetv3.py similarity index 100% rename from mobilenetv3.py rename to class2/test/mobilenetv3.py diff --git a/mobilenetv3_binary.pth b/class2/test/mobilenetv3_1.pth similarity index 100% rename from mobilenetv3_binary.pth rename to class2/test/mobilenetv3_1.pth diff --git a/class2/train_best_pt.py b/class2/train_best_pt.py new file mode 100644 index 0000000..667f66d --- /dev/null +++ b/class2/train_best_pt.py @@ -0,0 +1,144 @@ +import argparse +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from class2.test.mobilenetv3 import MobileNetV3_Large +import os + +def train(args): + # 创建输出目录 + os.makedirs("checkpoints", exist_ok=True) + + # 数据增强 & 预处理 + transform_train = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + transform_val = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + # 加载数据集 + train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train) + val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val) + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True) + + # 模型 + model = MobileNetV3_Large(num_classes=2) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + # 损失函数 & 优化器 + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + # 学习率调度器 + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + + # ========== 改进的早停 + 保存机制 ========== + best_val_acc = 0.0 + best_epoch = 0 + patience = args.patience # 默认 5 + wait = 0 # 连续未超越最佳的次数 + + print(f"开始训练,共 {args.epochs} 轮,早停耐心值: {patience}") + + for epoch in range(args.epochs): + # 训练阶段 + model.train() + running_loss = 0.0 + correct, total = 0, 0 + + for images, labels in train_loader: + images, labels = images.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(images) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + running_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + train_acc = 100. * correct / total + avg_loss = running_loss / len(train_loader) + scheduler.step() + + # 验证阶段 + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for images, labels in val_loader: + images, labels = images.to(device), labels.to(device) + outputs = model(images) + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + val_acc = 100. * correct / total + + print(f"[Epoch {epoch+1:2d}/{args.epochs}] " + f"Loss: {avg_loss:.4f} | " + f"Train Acc: {train_acc:.2f}% | " + f"Val Acc: {val_acc:.2f}% | " + f"LR: {scheduler.get_last_lr()[0]:.6f}") + + # ========== 核心逻辑:保存最后一个最高权重 + 改进早停 ========== + if val_acc >= best_val_acc: # ✅ 只要不低于最佳,就视为“保持高水平” + if val_acc > best_val_acc: + print(f"新的最高验证准确率: {val_acc:.2f}% (原: {best_val_acc:.2f}%)") + else: + print(f"验证准确率持平历史最佳: {val_acc:.2f}%") + + # 更新最佳 + best_val_acc = val_acc + best_epoch = epoch + 1 + wait = 0 # ✅ 重置等待计数 + + # 保存完整检查点(我们想要“最后一个”最高权重) + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'val_acc': val_acc, + 'best_val_acc': best_val_acc, + 'args': args + } + torch.save(checkpoint, "checkpoints/mobilenetv3_last_best.pth") + print(f"✅ 已保存:当前最佳权重(epoch {best_epoch})") + + else: + # ✅ 只有严格低于最佳才增加 wait + wait += 1 + print(f"验证准确率下降 ({val_acc:.2f}% < {best_val_acc:.2f}%),已连续 {wait}/{patience} 次未超越") + + if wait >= patience: + print(f"早停触发!连续 {patience} 次验证准确率未达到或超越历史最佳。") + break + + # 训练结束 + print(f"\n✅ 训练完成!") + print(f"最终最佳验证准确率: {best_val_acc:.2f}% (第 {best_epoch} 轮)") + print(f"最后一个最高权重已保存至: 'checkpoints/mobilenetv3_last_best.pth'") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data", type=str, default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata", help="数据集路径") + parser.add_argument("--epochs", type=int, default=100, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=32, help="批大小") + parser.add_argument("--lr", type=float, default=1e-3, help="学习率") + parser.add_argument("--patience", type=int, default=5, help="早停耐心值(连续多少次未达到或超越历史最佳才停止)") + args = parser.parse_args() + train(args) \ No newline at end of file diff --git a/class2/train_old.py b/class2/train_old.py new file mode 100644 index 0000000..50397f6 --- /dev/null +++ b/class2/train_old.py @@ -0,0 +1,89 @@ +import argparse +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from class2.test.mobilenetv3 import MobileNetV3_Large + +def train(args): + # 数据增强 & 预处理 + transform_train = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + transform_val = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + # 加载数据集 + train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train) + val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val) + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) + + # 模型 + model = MobileNetV3_Large(num_classes=2) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + # 损失函数 & 优化器 + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + # 训练 + for epoch in range(args.epochs): + model.train() + running_loss = 0.0 + correct, total = 0, 0 + + for images, labels in train_loader: + images, labels = images.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(images) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + train_acc = 100. * correct / total + print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%") + + # 验证 + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for images, labels in val_loader: + images, labels = images.to(device), labels.to(device) + outputs = model(images) + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + val_acc = 100. * correct / total + print(f"Validation Acc: {val_acc:.2f}%") + + # 保存模型 + torch.save(model.state_dict(), "mobilenetv3_binary.pth") + print("训练完成,模型已保存到 mobilenetv3_binary.pth") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data", type=str, default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata", help="数据集路径") + parser.add_argument("--epochs", type=int, default=100, help="训练轮数") + parser.add_argument("--batch_size", type=int, default=32, help="批大小") + parser.add_argument("--lr", type=float, default=1e-3, help="学习率") + args = parser.parse_args() + train(args) diff --git a/class3/classify.py b/class3/classify.py new file mode 100644 index 0000000..b37d53b --- /dev/null +++ b/class3/classify.py @@ -0,0 +1,43 @@ +import torch +from torchvision import transforms +from PIL import Image +from class2.test.mobilenetv3 import MobileNetV3_Large # 你的模型文件名 + +# ------------------------------- +# 1. 定义推理设备 +# ------------------------------- +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# ------------------------------- +# 2. 创建模型并加载权重 +# ------------------------------- +model = MobileNetV3_Large(num_classes=3) +model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device)) +model.to(device) +model.eval() # 推理模式 + +# ------------------------------- +# 3. 图片预处理 +# ------------------------------- +transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) +]) + +def predict(image_path): + img = Image.open(image_path).convert('RGB') + x = transform(img).unsqueeze(0).to(device) # 增加 batch 维度 + with torch.no_grad(): + outputs = model(x) + probs = torch.softmax(outputs, dim=1) + pred_class = torch.argmax(probs, dim=1).item() + return pred_class, probs.cpu().numpy() + +# ------------------------------- +# 4. 测试推理 +# ------------------------------- +image_path = "test/2.png" +cls, prob = predict(image_path) +print(f"Predicted class: {cls}, Probabilities: {prob}") diff --git a/class3/copy_jpg_images.py b/class3/copy_jpg_images.py new file mode 100644 index 0000000..1092410 --- /dev/null +++ b/class3/copy_jpg_images.py @@ -0,0 +1,67 @@ +import os +import shutil +from pathlib import Path + +def copy_png_images(source_folder, destination_folder): + """ + 从源文件夹递归搜索所有 .png 文件并复制到目标文件夹。 + + 参数: + source_folder (str): 源文件夹路径 + destination_folder (str): 目标文件夹路径 + """ + # 将字符串路径转换为 Path 对象,更方便操作 + src_path = Path(source_folder) + dest_path = Path(destination_folder) + + # 检查源文件夹是否存在 + if not src_path.exists(): + print(f"错误:源文件夹 '{source_folder}' 不存在。") + return + + if not src_path.is_dir(): + print(f"错误:'{source_folder}' 不是一个有效的文件夹。") + return + + # 如果目标文件夹不存在,则创建它 + dest_path.mkdir(parents=True, exist_ok=True) + + # 用于统计复制的文件数量 + copied_count = 0 + + # 使用 rglob 递归搜索所有 .png 文件(不区分大小写) + # rglob('*.[pP][nN][gG]') 可以匹配 .png, .PNG, .Png 等 + png_files = src_path.rglob('*.[jJ][pP][gG]') + + for png_file in png_files: + try: + # 计算目标文件的完整路径 + # 保持源文件夹的相对结构,但只保留文件名(可选:如果想保持结构,去掉 .name) + # 这里我们只复制文件名到目标文件夹,避免路径过长或结构复杂 + dest_file = dest_path / png_file.name + + # 处理重名文件:在文件名后添加序号 + counter = 1 + original_dest_file = dest_file + while dest_file.exists(): + dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}" + counter += 1 + + # 执行复制 + shutil.copy2(png_file, dest_file) # copy2 会保留文件的元数据(如修改时间) + print(f"已复制: {png_file} -> {dest_file}") + copied_count += 1 + + except Exception as e: + print(f"复制文件时出错: {png_file} - {e}") + + print(f"\n复制完成!共复制了 {copied_count} 个 PNG 文件到 '{destination_folder}'。") + +# ------------------ 主程序 ------------------ +if __name__ == "__main__": + # ====== 请在这里修改源文件夹和目标文件夹的路径 ====== + source_folder = r"/home/hx/桌面/git/class/data/folder_end" # 替换为你的源文件夹路径 + destination_folder = r"/home/hx/桌面/git/class/data/class0" # 替换为你的目标文件夹路径 + # ==================================================== + + copy_png_images(source_folder, destination_folder) \ No newline at end of file diff --git a/class3/data_load.py b/class3/data_load.py new file mode 100644 index 0000000..bd95213 --- /dev/null +++ b/class3/data_load.py @@ -0,0 +1,24 @@ +from torchvision import datasets, transforms +from torch.utils.data import DataLoader + +# 图像预处理 +transform = transforms.Compose([ + transforms.Resize((224, 224)), # 调整为 MobileNet 输入尺寸 + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) # ImageNet 预训练均值方差 +]) + +# 数据集路径 +train_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/train' +val_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/val' + +# 加载数据集 +train_dataset = datasets.ImageFolder(train_dir, transform=transform) +val_dataset = datasets.ImageFolder(val_dir, transform=transform) + +# DataLoader +train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4) +val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4) + +print(train_dataset.class_to_idx) # {'class0': 0, 'class1': 1} diff --git a/class3/test/1.png b/class3/test/1.png new file mode 100644 index 0000000..8df7feb Binary files /dev/null and b/class3/test/1.png differ diff --git a/class3/test/2.png b/class3/test/2.png new file mode 100644 index 0000000..0512d78 Binary files /dev/null and b/class3/test/2.png differ diff --git a/class3/test/mobilenetv3.py b/class3/test/mobilenetv3.py new file mode 100644 index 0000000..8058c11 --- /dev/null +++ b/class3/test/mobilenetv3.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +# h-swish 激活函数 +class hswish(nn.Module): + def forward(self, x): + return x * F.relu6(x + 3, inplace=True) / 6 + + +# SE 模块 +class SE_Module(nn.Module): + def __init__(self, channel, reduction=4): + super(SE_Module, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +# Bneck 模块(修改:动态生成 SE_Module) +class Bneck(nn.Module): + def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s): + super(Bneck, self).__init__() + self.stride = s + + self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(expand_size) + self.nolinear1 = nolinear + + self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s, + padding=kernel_size // 2, groups=expand_size, bias=False) + self.bn2 = nn.BatchNorm2d(expand_size) + self.nolinear2 = nolinear + + # 动态生成 SE 模块 + self.se = SE_Module(expand_size) if use_se else None + + self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(out_size) + + self.shortcut = (self.stride == 1 and in_size == out_size) + + def forward(self, x): + out = self.nolinear1(self.bn1(self.conv1(x))) + out = self.nolinear2(self.bn2(self.conv2(out))) + if self.se is not None: + out = self.se(out) + out = self.bn3(self.conv3(out)) + if self.shortcut: + return x + out + else: + return out + + +class MobileNetV3_Large(nn.Module): + def __init__(self, num_classes=1000): + super(MobileNetV3_Large, self).__init__() + self.num_classes = num_classes + self.init_params() + + # stem + self.top = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(16), + hswish() + ) + + # bottlenecks(修改:use_se 参数替代 semodule) + self.bneck = nn.Sequential( + Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1), + Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2), + Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1), + + Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2), + Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1), + Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1), + + Bneck(3, 40, 240, 80, hswish(), False, 2), + Bneck(3, 80, 200, 80, hswish(), False, 1), + Bneck(3, 80, 184, 80, hswish(), False, 1), + Bneck(3, 80, 184, 80, hswish(), False, 1), + Bneck(3, 80, 480, 112, hswish(), True, 1), + Bneck(3, 112, 672, 112, hswish(), True, 1), + + Bneck(5, 112, 672, 160, hswish(), True, 1), + Bneck(5, 160, 672, 160, hswish(), True, 2), + Bneck(5, 160, 960, 160, hswish(), True, 1), + ) + + # final conv + self.bottom = nn.Sequential( + nn.Conv2d(160, 960, kernel_size=1, bias=False), + nn.BatchNorm2d(960), + hswish() + ) + + # classifier + self.last = nn.Sequential( + nn.Linear(960, 1280), + nn.BatchNorm1d(1280), + hswish() + ) + self.linear = nn.Linear(1280, num_classes) + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.001) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, x): + out = self.top(x) + out = self.bneck(out) + out = self.bottom(out) + out = F.avg_pool2d(out, out.size(2)) # 自适应池化 + out = out.view(out.size(0), -1) + out = self.last(out) + out = self.linear(out) + return out + + @staticmethod + @staticmethod + def from_pretrained(num_classes=1000): + """从 torchvision 自动下载 ImageNet 预训练,并改成指定类数""" + print("Downloading official torchvision MobileNetV3-Large pretrained weights...") + from torchvision.models import mobilenet_v3_large + + official_model = mobilenet_v3_large(pretrained=True) + model = MobileNetV3_Large(num_classes=num_classes) # 初始化时设定类别数量 + model.load_state_dict(official_model.state_dict(), strict=False) + + # 替换分类头 + in_features = model.linear.in_features + model.linear = nn.Linear(in_features, num_classes) # 修改为指定类数的输出 + print(f"Replaced classifier head: {in_features} -> {num_classes}") + return model + +if __name__ == "__main__": + # 测试三分类模型 + num_classes = 3 # 设置为你需要的类别数 + model = MobileNetV3_Large.from_pretrained(num_classes=num_classes) + x = torch.randn(4, 3, 224, 224) + y = model(x) + print(y.shape) # 应输出 [4, 3] diff --git a/class3/test/mobilenetv3_binary.pth b/class3/test/mobilenetv3_binary.pth new file mode 100644 index 0000000..56711ee Binary files /dev/null and b/class3/test/mobilenetv3_binary.pth differ diff --git a/train.py b/class3/train.py similarity index 95% rename from train.py rename to class3/train.py index dcf2aca..0225103 100644 --- a/train.py +++ b/class3/train.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader -from mobilenetv3 import MobileNetV3_Large +from class3.test.mobilenetv3 import MobileNetV3_Large def train(args): # 数据增强 & 预处理 @@ -29,8 +29,8 @@ def train(args): train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) - # 模型 - model = MobileNetV3_Large(num_classes=2) + # 模型 更改分类类别数量 + model = MobileNetV3_Large(num_classes=3) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device)