import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class hswish(nn.Module):
    """h-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        # relu6 runs in place on the temporary ``x + 3``; ``x`` itself is
        # not modified.
        return x * F.relu6(x + 3, inplace=True) / 6


class SE_Module(nn.Module):
    """Squeeze-and-excitation gate applied channel-wise.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor for the hidden width of the two-layer gate.
    """

    def __init__(self, channel, reduction=4):
        super(SE_Module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: one scalar per (sample, channel).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel gate in (0, 1), reshaped for broadcasting.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class Bneck(nn.Module):
    """Inverted-residual bottleneck: 1x1 expand -> depthwise -> (SE) -> 1x1 project.

    Args:
        kernel_size: Kernel size of the depthwise convolution.
        in_size: Input channels.
        expand_size: Channels after the 1x1 expansion.
        out_size: Output channels.
        nolinear: Activation module used after the first two convolutions
            (the same instance is used for both).
        use_se: When true, an ``SE_Module`` is built for ``expand_size``
            channels and applied after the depthwise stage.
        s: Stride of the depthwise convolution.
    """

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
        super(Bneck, self).__init__()
        self.stride = s

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear

        self.conv2 = nn.Conv2d(
            expand_size,
            expand_size,
            kernel_size,
            stride=s,
            padding=kernel_size // 2,
            groups=expand_size,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear

        # The SE module is created here so its width always matches expand_size.
        self.se = SE_Module(expand_size) if use_se else None

        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        # A residual connection is only possible when the shape is preserved.
        self.shortcut = (self.stride == 1 and in_size == out_size)

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        if self.se is not None:
            out = self.se(out)
        out = self.bn3(self.conv3(out))
        if self.shortcut:
            return x + out
        return out


class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large style classifier.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.num_classes = num_classes

        # Stem.
        self.top = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish(),
        )

        # Bottlenecks: (kernel, in, expand, out, activation, use_se, stride).
        self.bneck = nn.Sequential(
            Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
            Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
            Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),
            Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(3, 40, 240, 80, hswish(), False, 2),
            Bneck(3, 80, 200, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 480, 112, hswish(), True, 1),
            Bneck(3, 112, 672, 112, hswish(), True, 1),
            Bneck(5, 112, 672, 160, hswish(), True, 1),
            Bneck(5, 160, 672, 160, hswish(), True, 2),
            Bneck(5, 160, 960, 160, hswish(), True, 1),
        )

        # Final 1x1 convolution.
        self.bottom = nn.Sequential(
            nn.Conv2d(160, 960, kernel_size=1, bias=False),
            nn.BatchNorm2d(960),
            hswish(),
        )

        # Classifier.
        self.last = nn.Sequential(
            nn.Linear(960, 1280),
            nn.BatchNorm1d(1280),
            hswish(),
        )
        self.linear = nn.Linear(1280, num_classes)

        # Must run after every sub-module is registered; calling it earlier
        # iterates over an empty module tree and initialises nothing.
        self.init_params()

    def init_params(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        out = self.top(x)
        out = self.bneck(out)
        out = self.bottom(out)
        # Global average pooling to 1x1; works for non-square inputs too.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.last(out)
        out = self.linear(out)
        return out

    @staticmethod
    def from_pretrained(num_classes=1000):
        """Download torchvision's ImageNet MobileNetV3-Large weights and build a model.

        The torchvision state dict is loaded non-strictly, then the classifier
        head is replaced with a fresh ``nn.Linear`` of ``num_classes`` outputs.

        Args:
            num_classes: Number of output classes of the returned model.

        Returns:
            A ``MobileNetV3_Large`` instance.
        """
        print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
        from torchvision.models import mobilenet_v3_large

        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is for older installs.
        official_model = mobilenet_v3_large(pretrained=True)

        model = MobileNetV3_Large(num_classes=num_classes)
        load_result = model.load_state_dict(official_model.state_dict(), strict=False)

        # strict=False silently skips every non-matching key, so report how
        # much was actually transferred instead of assuming success.
        total = len(model.state_dict())
        loaded = total - len(load_result.missing_keys)
        if loaded == 0:
            warnings.warn(
                "No torchvision parameter names matched this model; "
                "the returned network is NOT pretrained "
                f"({len(load_result.unexpected_keys)} source tensors were ignored)."
            )
        else:
            print(f"Loaded {loaded}/{total} tensors from torchvision weights")

        # Replace the classifier head with one sized for num_classes.
        in_features = model.linear.in_features
        model.linear = nn.Linear(in_features, num_classes)
        print(f"Replaced classifier head: {in_features} -> {num_classes}")

        return model


if __name__ == "__main__":
    # Smoke test with a three-class model.
    num_classes = 3  # set to the number of classes you need
    model = MobileNetV3_Large.from_pretrained(num_classes=num_classes)
    x = torch.randn(4, 3, 224, 224)
    y = model(x)
    print(y.shape)  # expected: [4, 3]