Files
zjsh_classification/class2/test/mobilenetv3.py
2025-08-14 18:27:52 +08:00

164 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# h-swish activation: x * ReLU6(x + 3) / 6
class hswish(nn.Module):
    """Hard-swish activation used throughout MobileNetV3.

    Computes ``x * relu6(x + 3) / 6`` element-wise. The caller's tensor is
    never modified: the in-place ReLU6 acts on the temporary ``x + 3``.
    """

    def forward(self, x):
        shifted = x + 3
        gate = F.relu6(shifted, inplace=True)
        return x * gate / 6
# Squeeze-and-Excitation (SE) channel-attention module
class SE_Module(nn.Module):
    """Squeeze-and-Excitation block.

    The block works in four steps:

    1. Global-average-pools every channel.
    2. Feeds the pooled channel vector through a bias-free bottleneck MLP
       (channel -> channel // reduction -> channel).
    3. Squashes the MLP output with a sigmoid.
    4. Rescales each input channel by its resulting weight.

    Args:
        channel: number of input (and output) channels.
        reduction: divisor that sets the hidden width of the MLP.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x must be 4-D with channels on dim 1; the unpack raises otherwise.
        batch, channels, _, _ = x.shape
        squeezed = self.avg_pool(x).view(batch, channels)
        weights = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * weights.expand_as(x)
# Bneck (inverted-residual) block -- modified: the SE module is created only when use_se is set
class Bneck(nn.Module):
    """Inverted-residual bottleneck block.

    The stages are: 1x1 expand -> BN -> act -> kxk depthwise (strided) ->
    BN -> act -> optional SE -> 1x1 project -> BN. There is no activation
    after the projection.

    Args:
        kernel_size: depthwise kernel size.
        in_size: number of input channels.
        expand_size: number of hidden (expanded) channels.
        out_size: number of output channels.
        nolinear: activation module applied after the expand and depthwise
            stages. The same instance is used for both.
        use_se: if truthy, an ``SE_Module`` is applied to the expanded features.
        s: stride of the depthwise convolution.

    A residual connection is added when ``s == 1`` and ``in_size == out_size``.
    """

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
        super().__init__()
        self.stride = s
        # 1x1 pointwise expansion: in_size -> expand_size
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # depthwise conv (groups == channels); padding k // 2 keeps the size for odd k at stride 1
        self.conv2 = nn.Conv2d(
            expand_size,
            expand_size,
            kernel_size,
            stride=s,
            padding=kernel_size // 2,
            groups=expand_size,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.se = SE_Module(expand_size) if use_se else None
        # 1x1 linear projection: expand_size -> out_size
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)
        self.shortcut = self.stride == 1 and in_size == out_size

    def forward(self, x):
        y = self.conv1(x)
        y = self.nolinear1(self.bn1(y))
        y = self.conv2(y)
        y = self.nolinear2(self.bn2(y))
        if self.se is not None:
            y = self.se(y)
        y = self.bn3(self.conv3(y))
        return x + y if self.shortcut else y
class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large style classifier assembled from ``Bneck`` blocks.

    Pipeline:

    1. A 3x3 stride-2 conv stem (16 channels).
    2. 15 ``Bneck`` blocks.
    3. A 1x1 conv to 960 channels.
    4. Global average pooling.
    5. Linear(960, 1280) + BatchNorm1d + h-swish.
    6. Linear(1280, num_classes).

    Args:
        num_classes: number of logits produced by the final linear layer.

    Note:
        ``self.last`` contains a BatchNorm1d, so a forward pass in training
        mode needs more than one sample per batch.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        # stem
        self.top = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish(),
        )
        # Bottleneck table: (kernel, in, expand, out, activation, use_se, stride).
        # NOTE(review): the MobileNetV3 paper (Table 1) and torchvision give the
        # 13th block (112 -> 160) stride 2 and the 14th block expand 960 with
        # stride 1. Here the 13th has stride 1 and the 14th has expand 672 with
        # stride 2. Correcting it changes bneck.12/13 strides and bneck.13 weight
        # shapes, which would invalidate existing checkpoints, so it is only
        # flagged.
        self.bneck = nn.Sequential(
            Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
            Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
            Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),
            Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(3, 40, 240, 80, hswish(), False, 2),
            Bneck(3, 80, 200, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 480, 112, hswish(), True, 1),
            Bneck(3, 112, 672, 112, hswish(), True, 1),
            Bneck(5, 112, 672, 160, hswish(), True, 1),
            Bneck(5, 160, 672, 160, hswish(), True, 2),
            Bneck(5, 160, 960, 160, hswish(), True, 1),
        )
        # final 1x1 conv
        self.bottom = nn.Sequential(
            nn.Conv2d(160, 960, kernel_size=1, bias=False),
            nn.BatchNorm2d(960),
            hswish(),
        )
        # classifier
        self.last = nn.Sequential(
            nn.Linear(960, 1280),
            nn.BatchNorm1d(1280),
            hswish(),
        )
        self.linear = nn.Linear(1280, num_classes)
        # Must run after every layer exists: self.modules() only sees registered
        # submodules, so calling it earlier would initialise nothing.
        self.init_params()

    def init_params(self):
        """Re-initialise all weights in place.

        Each layer type gets its own scheme:

        - Conv2d: Kaiming-normal (fan_out) weights, zero bias.
        - BatchNorm1d/2d: weight 1, bias 0.
        - Linear, including those inside SE modules: N(0, 0.001) weights,
          zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape (batch, number of outputs of ``self.linear``)."""
        out = self.top(x)
        out = self.bneck(out)
        out = self.bottom(out)
        # Global average pooling: adaptive pooling reduces any H x W map to 1x1,
        # so non-square inputs also yield exactly 960 features.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.last(out)
        return self.linear(out)

    @staticmethod
    def _torchvision_prefix_map(bneck_layout):
        """Map torchvision ``mobilenet_v3_large`` state-dict prefixes to this model's.

        Args:
            bneck_layout: sequence of ``(has_expand, use_se)`` pairs, one per
                ``Bneck`` in order. ``has_expand`` is False where torchvision
                omits the 1x1 expansion conv (expand size == input size).

        Returns:
            List of ``(torchvision_prefix, own_prefix)`` pairs, each ending in
            ``"."``. SE sub-blocks are deliberately not mapped: torchvision's
            SE uses biased 1x1 convs and a hard-sigmoid, and its squeeze widths
            differ for some blocks.
        """
        layout = list(bneck_layout)
        pairs = [("features.0.0.", "top.0."), ("features.0.1.", "top.1.")]
        for i, (has_expand, use_se) in enumerate(layout):
            # Sub-module order inside torchvision's InvertedResidual.block;
            # None marks the SE slot, which takes an index but is not copied.
            stages = [("conv1", "bn1")] if has_expand else []
            stages.append(("conv2", "bn2"))
            if use_se:
                stages.append(None)
            stages.append(("conv3", "bn3"))
            for j, stage in enumerate(stages):
                if stage is None:
                    continue
                conv, bn = stage
                tv_prefix = f"features.{i + 1}.block.{j}."
                pairs.append((tv_prefix + "0.", f"bneck.{i}.{conv}."))
                pairs.append((tv_prefix + "1.", f"bneck.{i}.{bn}."))
        tail = f"features.{len(layout) + 1}."
        pairs += [
            (tail + "0.", "bottom.0."),
            (tail + "1.", "bottom.1."),
            ("classifier.0.", "last.0."),
            ("classifier.3.", "linear."),
        ]
        return pairs

    @staticmethod
    def from_pretrained_for_binary():
        """Build a 2-class model initialised from torchvision's ImageNet weights.

        Steps:

        1. Download the official torchvision MobileNetV3-Large ImageNet
           weights.
        2. Copy every tensor that has a structural counterpart here and the
           same shape. This covers the stem, the bottleneck convs/BNs, the
           final conv and the classifier linears.
        3. Replace ``linear`` with a freshly initialised 1280 -> 2 head.

        These parts are not transferred and keep this class's own
        initialisation:

        - the SE modules;
        - the first bneck's expansion conv (torchvision has none there);
        - ``last.1`` (a BatchNorm1d that torchvision lacks);
        - any tensor whose shape differs.

        Returns:
            A ``MobileNetV3_Large`` with a 2-output classification head.
        """
        print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
        from torchvision.models import mobilenet_v3_large
        try:
            # torchvision >= 0.13 uses a weights enum; `pretrained=` is deprecated.
            from torchvision.models import MobileNet_V3_Large_Weights
        except ImportError:
            official_model = mobilenet_v3_large(pretrained=True)
        else:
            official_model = mobilenet_v3_large(
                weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1
            )

        model = MobileNetV3_Large(num_classes=1000)
        # torchvision names ("features.0.0.weight") share no keys with ours
        # ("top.0.weight"), so load_state_dict(strict=False) on the raw dict
        # would silently load nothing; translate names explicitly instead.
        layout = [
            (blk.conv1.in_channels != blk.conv1.out_channels, blk.se is not None)
            for blk in model.bneck
        ]
        prefix_map = MobileNetV3_Large._torchvision_prefix_map(layout)
        own_state = model.state_dict()
        matched = {}
        for tv_key, tensor in official_model.state_dict().items():
            for tv_prefix, own_prefix in prefix_map:
                if tv_key.startswith(tv_prefix):
                    own_key = own_prefix + tv_key[len(tv_prefix):]
                    if own_key in own_state and own_state[own_key].shape == tensor.shape:
                        matched[own_key] = tensor
                    break
        model.load_state_dict(matched, strict=False)
        print(f"Loaded {len(matched)}/{len(own_state)} tensors from torchvision weights")

        # Replace the classification head
        in_features = model.linear.in_features
        model.linear = nn.Linear(in_features, 2)
        model.num_classes = 2
        print(f"Replaced classifier head: {in_features} -> 2")
        return model
if __name__ == "__main__":
    # Smoke test: build the binary classifier and check the output shape.
    model = MobileNetV3_Large.from_pretrained_for_binary()
    # eval() makes BatchNorm use (not overwrite) its running statistics;
    # no_grad() skips building an autograd graph for this throwaway pass.
    model.eval()
    x = torch.randn(4, 3, 224, 224)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([4, 2])