加一个三分类
This commit is contained in:
2
.idea/class.iml
generated
2
.idea/class.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="jdk" jdkName="yolov11" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="yolov11" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
1
.idea/vcs.xml
generated
1
.idea/vcs.xml
generated
@ -2,5 +2,6 @@
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
BIN
class2/checkpoints/mobilenetv3_best.pth
Normal file
BIN
class2/checkpoints/mobilenetv3_best.pth
Normal file
Binary file not shown.
BIN
class2/checkpoints/mobilenetv3_last_best.pth
Normal file
BIN
class2/checkpoints/mobilenetv3_last_best.pth
Normal file
Binary file not shown.
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
from mobilenetv3 import MobileNetV3_Large # 你的模型文件名
|
||||
from class2.test.mobilenetv3 import MobileNetV3_Large # 你的模型文件名
|
||||
|
||||
# -------------------------------
|
||||
# 1. 定义推理设备
|
||||
@ -38,6 +38,6 @@ def predict(image_path):
|
||||
# -------------------------------
|
||||
# 4. 测试推理
|
||||
# -------------------------------
|
||||
image_path = "2.png"
|
||||
image_path = "test/2.png"
|
||||
cls, prob = predict(image_path)
|
||||
print(f"Predicted class: {cls}, Probabilities: {prob}")
|
||||
@ -10,8 +10,8 @@ transform = transforms.Compose([
|
||||
])
|
||||
|
||||
# 数据集路径
|
||||
train_dir = 'dataset_root/train'
|
||||
val_dir = 'dataset_root/val'
|
||||
train_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/train'
|
||||
val_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/val'
|
||||
|
||||
# 加载数据集
|
||||
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
|
||||
|
Before Width: | Height: | Size: 3.9 MiB After Width: | Height: | Size: 3.9 MiB |
|
Before Width: | Height: | Size: 6.7 MiB After Width: | Height: | Size: 6.7 MiB |
Binary file not shown.
144
class2/train_best_pt.py
Normal file
144
class2/train_best_pt.py
Normal file
@ -0,0 +1,144 @@
|
||||
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from class2.test.mobilenetv3 import MobileNetV3_Large
import os


def _build_transforms():
    """Return (train, val) preprocessing pipelines: 224x224, ImageNet stats."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, val_tf


def train(args):
    """Train MobileNetV3-Large (2-class head) with early stopping.

    The checkpoint is (re)saved whenever the validation accuracy reaches OR
    ties the best seen so far, so the file on disk is always the *last*
    checkpoint that achieved the best accuracy. Only strictly-worse epochs
    count toward the early-stopping patience.
    """
    # Output directory for checkpoints (relative to the working directory).
    os.makedirs("checkpoints", exist_ok=True)

    train_tf, val_tf = _build_transforms()

    # Datasets / loaders; folder layout: <data>/train and <data>/val.
    train_set = datasets.ImageFolder(root=f"{args.data}/train", transform=train_tf)
    val_set = datasets.ImageFolder(root=f"{args.data}/val", transform=val_tf)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    # Model on GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MobileNetV3_Large(num_classes=2).to(device)

    # Loss, optimizer, step-decay LR schedule.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # Early-stopping bookkeeping.
    best_val_acc = 0.0
    best_epoch = 0
    patience = args.patience
    wait = 0  # consecutive epochs that failed to match/beat the best

    print(f"开始训练,共 {args.epochs} 轮,早停耐心值: {patience}")

    for epoch in range(args.epochs):
        # ---- training pass ----
        model.train()
        loss_sum = 0.0
        seen, hits = 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            seen += labels.size(0)
            hits += logits.max(1)[1].eq(labels).sum().item()

        train_acc = 100. * hits / seen
        avg_loss = loss_sum / len(train_loader)
        scheduler.step()

        # ---- validation pass ----
        model.eval()
        seen, hits = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                seen += labels.size(0)
                hits += model(images).max(1)[1].eq(labels).sum().item()
        val_acc = 100. * hits / seen

        print(f"[Epoch {epoch+1:2d}/{args.epochs}] "
              f"Loss: {avg_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Acc: {val_acc:.2f}% | "
              f"LR: {scheduler.get_last_lr()[0]:.6f}")

        # Save on match-or-beat; count only strictly-worse epochs for patience.
        if val_acc >= best_val_acc:
            if val_acc > best_val_acc:
                print(f"新的最高验证准确率: {val_acc:.2f}% (原: {best_val_acc:.2f}%)")
            else:
                print(f"验证准确率持平历史最佳: {val_acc:.2f}%")

            best_val_acc = val_acc
            best_epoch = epoch + 1
            wait = 0  # reset the early-stopping counter

            # Full checkpoint so training can be resumed exactly.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'best_val_acc': best_val_acc,
                'args': args
            }, "checkpoints/mobilenetv3_last_best.pth")
            print(f"✅ 已保存:当前最佳权重(epoch {best_epoch})")
        else:
            wait += 1
            print(f"验证准确率下降 ({val_acc:.2f}% < {best_val_acc:.2f}%),已连续 {wait}/{patience} 次未超越")
            if wait >= patience:
                print(f"早停触发!连续 {patience} 次验证准确率未达到或超越历史最佳。")
                break

    print(f"\n✅ 训练完成!")
    print(f"最终最佳验证准确率: {best_val_acc:.2f}% (第 {best_epoch} 轮)")
    print(f"最后一个最高权重已保存至: 'checkpoints/mobilenetv3_last_best.pth'")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata", help="数据集路径")
    parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="批大小")
    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
    parser.add_argument("--patience", type=int, default=5, help="早停耐心值(连续多少次未达到或超越历史最佳才停止)")
    args = parser.parse_args()
    train(args)
89
class2/train_old.py
Normal file
89
class2/train_old.py
Normal file
@ -0,0 +1,89 @@
|
||||
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from class2.test.mobilenetv3 import MobileNetV3_Large


def train(args):
    """Fixed-epoch training loop for the 2-class MobileNetV3 classifier.

    Trains for exactly ``args.epochs`` epochs (no early stopping), prints
    train/validation accuracy per epoch, then saves the final weights.
    """
    # Preprocessing: 224x224 input, ImageNet normalization stats.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Datasets / loaders; folder layout: <data>/train and <data>/val.
    train_set = datasets.ImageFolder(root=f"{args.data}/train", transform=train_tf)
    val_set = datasets.ImageFolder(root=f"{args.data}/val", transform=val_tf)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Model on GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MobileNetV3_Large(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        seen, hits = 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            seen += labels.size(0)
            hits += logits.max(1)[1].eq(labels).sum().item()

        train_acc = 100. * hits / seen
        print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%")

        # ---- validation pass ----
        model.eval()
        seen, hits = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                seen += labels.size(0)
                hits += model(images).max(1)[1].eq(labels).sum().item()

        val_acc = 100. * hits / seen
        print(f"Validation Acc: {val_acc:.2f}%")

    # Persist the final weights (state_dict only).
    torch.save(model.state_dict(), "mobilenetv3_binary.pth")
    print("训练完成,模型已保存到 mobilenetv3_binary.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata", help="数据集路径")
    parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="批大小")
    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
    args = parser.parse_args()
    train(args)
43
class3/classify.py
Normal file
43
class3/classify.py
Normal file
@ -0,0 +1,43 @@
|
||||
import torch
from torchvision import transforms
from PIL import Image
from class2.test.mobilenetv3 import MobileNetV3_Large  # model definition

# -------------------------------
# 1. Inference device
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# 2. Build the model and load weights
# -------------------------------
# NOTE(review): the checkpoint name suggests a 2-class ("binary") head, but the
# model is built with num_classes=3 — load_state_dict will raise on a size
# mismatch of the final Linear layer unless the file actually holds 3-class
# weights. TODO confirm the checkpoint matches this head.
model = MobileNetV3_Large(num_classes=3)
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
model.to(device)
model.eval()  # inference mode (disables dropout / uses BN running stats)

# -------------------------------
# 3. Image preprocessing
# -------------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def predict(image_path):
    """Classify one image file.

    Returns (predicted class index, softmax probabilities as a numpy array).
    """
    img = Image.open(image_path).convert('RGB')
    batch = transform(img).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    pred_class = torch.argmax(probs, dim=1).item()
    return pred_class, probs.cpu().numpy()


# -------------------------------
# 4. Run a test prediction
# -------------------------------
image_path = "test/2.png"
cls, prob = predict(image_path)
print(f"Predicted class: {cls}, Probabilities: {prob}")
67
class3/copy_jpg_images.py
Normal file
67
class3/copy_jpg_images.py
Normal file
@ -0,0 +1,67 @@
|
||||
import os
import shutil
from pathlib import Path


def copy_png_images(source_folder, destination_folder):
    """Recursively find every .jpg file under *source_folder* and copy it
    into *destination_folder* (flattened — sub-directory structure dropped).

    Name collisions are resolved by appending ``_1``, ``_2``, ... to the stem.
    Metadata (mtime etc.) is preserved via ``shutil.copy2``.

    NOTE: despite the historical function name, the glob matches *.jpg in any
    case (.jpg/.JPG/...), not *.png. The name is kept for backward
    compatibility with existing callers; the docs and output message now say
    JPG (the original ones said PNG, which was wrong).

    Args:
        source_folder (str): directory to search recursively.
        destination_folder (str): directory to copy into; created if missing.
    """
    src_path = Path(source_folder)
    dest_path = Path(destination_folder)

    # Validate the source directory before doing any work.
    if not src_path.exists():
        print(f"错误:源文件夹 '{source_folder}' 不存在。")
        return
    if not src_path.is_dir():
        print(f"错误:'{source_folder}' 不是一个有效的文件夹。")
        return

    # Create the destination (including parents) if needed.
    dest_path.mkdir(parents=True, exist_ok=True)

    copied_count = 0

    # Case-insensitive match for .jpg: matches .jpg, .JPG, .Jpg, ...
    jpg_files = src_path.rglob('*.[jJ][pP][gG]')

    for jpg_file in jpg_files:
        try:
            # Flatten: only the file name is kept in the destination.
            dest_file = dest_path / jpg_file.name

            # Resolve duplicate names with a numeric suffix (_1, _2, ...).
            counter = 1
            original_dest_file = dest_file
            while dest_file.exists():
                dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}"
                counter += 1

            shutil.copy2(jpg_file, dest_file)  # copy2 preserves file metadata
            print(f"已复制: {jpg_file} -> {dest_file}")
            copied_count += 1

        except Exception as e:
            # Best-effort: report and keep copying the remaining files.
            print(f"复制文件时出错: {jpg_file} - {e}")

    # Message fixed: this function copies JPG files, not PNG.
    print(f"\n复制完成!共复制了 {copied_count} 个 JPG 文件到 '{destination_folder}'。")


# ------------------ main ------------------
if __name__ == "__main__":
    # ====== Edit these two paths as needed ======
    source_folder = r"/home/hx/桌面/git/class/data/folder_end"
    destination_folder = r"/home/hx/桌面/git/class/data/class0"
    # ============================================

    copy_png_images(source_folder, destination_folder)
24
class3/data_load.py
Normal file
24
class3/data_load.py
Normal file
@ -0,0 +1,24 @@
|
||||
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing: resize to the MobileNet input size, then normalize with the
# ImageNet pretraining statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Dataset locations (absolute paths on the training machine).
train_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/train'
val_dir = '/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata/val'

# One ImageFolder per split; sub-folder names become the class labels.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# Loaders: shuffle only the training split.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(train_dataset.class_to_idx)  # e.g. {'class0': 0, 'class1': 1}
BIN
class3/test/1.png
Normal file
BIN
class3/test/1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.9 MiB |
BIN
class3/test/2.png
Normal file
BIN
class3/test/2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.7 MiB |
164
class3/test/mobilenetv3.py
Normal file
164
class3/test/mobilenetv3.py
Normal file
@ -0,0 +1,164 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
# h-swish activation (MobileNetV3): x * ReLU6(x + 3) / 6.
class hswish(nn.Module):
    """Hard-swish activation, a cheap piecewise approximation of swish."""

    def forward(self, x):
        gate = F.relu6(x + 3, inplace=True) / 6
        return x * gate
|
||||
|
||||
# Squeeze-and-Excitation block: channel-wise reweighting learned from
# globally pooled features.
class SE_Module(nn.Module):
    """Squeeze (global avg-pool) then excite (2-layer MLP + sigmoid) and
    rescale the input channels by the resulting weights."""

    def __init__(self, channel, reduction=4):
        super(SE_Module, self).__init__()
        # Attribute names kept stable so existing state_dicts still load.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        # Squeeze: (B, C, H, W) -> (B, C)
        weights = self.avg_pool(x).view(batch, channels)
        # Excite: (B, C) -> per-channel gates in (0, 1), shaped for broadcast.
        weights = self.fc(weights).view(batch, channels, 1, 1)
        return x * weights.expand_as(x)
|
||||
|
||||
# MobileNetV3 inverted-residual bottleneck ("bneck"); builds its SE module
# dynamically from use_se.
class Bneck(nn.Module):
    """expand (1x1) -> depthwise (kxk) -> optional SE -> project (1x1),
    with an identity shortcut when stride == 1 and in_size == out_size."""

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
        super(Bneck, self).__init__()
        self.stride = s

        # 1x1 expansion conv
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear

        # depthwise conv (groups == channels)
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear

        # Squeeze-and-Excitation, built only when requested.
        self.se = SE_Module(expand_size) if use_se else None

        # 1x1 linear projection (no activation after bn3).
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        # Identity shortcut is shape-safe only at stride 1 with equal channels.
        self.shortcut = (self.stride == 1 and in_size == out_size)

    def forward(self, x):
        y = self.nolinear1(self.bn1(self.conv1(x)))
        y = self.nolinear2(self.bn2(self.conv2(y)))
        if self.se is not None:
            y = self.se(y)
        y = self.bn3(self.conv3(y))
        return x + y if self.shortcut else y
|
||||
|
||||
class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large image classifier built from Bneck blocks.

    Layout: stem conv (``top``) -> 15 bottlenecks (``bneck``) -> 1x1 conv
    (``bottom``) -> global average pool -> FC head (``last`` + ``linear``).

    Args:
        num_classes: number of output logits of the final Linear layer.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.num_classes = num_classes

        # stem: 3 -> 16 channels, stride 2
        self.top = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish()
        )

        # bottleneck stack (kernel, in, expand, out, activation, use_se, stride)
        self.bneck = nn.Sequential(
            Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
            Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
            Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),

            Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),

            Bneck(3, 40, 240, 80, hswish(), False, 2),
            Bneck(3, 80, 200, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 480, 112, hswish(), True, 1),
            Bneck(3, 112, 672, 112, hswish(), True, 1),

            Bneck(5, 112, 672, 160, hswish(), True, 1),
            Bneck(5, 160, 672, 160, hswish(), True, 2),
            Bneck(5, 160, 960, 160, hswish(), True, 1),
        )

        # final conv: 160 -> 960 channels
        self.bottom = nn.Sequential(
            nn.Conv2d(160, 960, kernel_size=1, bias=False),
            nn.BatchNorm2d(960),
            hswish()
        )

        # classifier head
        self.last = nn.Sequential(
            nn.Linear(960, 1280),
            nn.BatchNorm1d(1280),
            hswish()
        )
        self.linear = nn.Linear(1280, num_classes)

        # BUGFIX: this was originally called at the TOP of __init__, before any
        # submodule existed, so self.modules() found nothing and the custom
        # initialization silently never ran. It must run after layer creation.
        self.init_params()

    def init_params(self):
        """Kaiming init for convs, unit/zero for norms, small-normal for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.top(x)
        out = self.bneck(out)
        out = self.bottom(out)
        # Global average pool; uses the H dimension as the kernel size, so it
        # assumes a square feature map (square input images).
        out = F.avg_pool2d(out, out.size(2))
        out = out.view(out.size(0), -1)
        out = self.last(out)
        out = self.linear(out)
        return out

    # BUGFIX: the @staticmethod decorator was stacked twice in the original.
    @staticmethod
    def from_pretrained(num_classes=1000):
        """Download torchvision's ImageNet-pretrained MobileNetV3-Large and
        adapt the classifier head to *num_classes* outputs."""
        print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
        from torchvision.models import mobilenet_v3_large

        official_model = mobilenet_v3_large(pretrained=True)
        model = MobileNetV3_Large(num_classes=num_classes)
        # NOTE(review): the parameter names here (top/bneck/bottom/last/linear)
        # do not match torchvision's (features.*/classifier.*), so with
        # strict=False this load most likely copies nothing — verify before
        # relying on the "pretrained" behavior.
        model.load_state_dict(official_model.state_dict(), strict=False)

        # Replace the classification head with a fresh one of the right width.
        in_features = model.linear.in_features
        model.linear = nn.Linear(in_features, num_classes)
        print(f"Replaced classifier head: {in_features} -> {num_classes}")
        return model
|
||||
if __name__ == "__main__":
    # Smoke-test: build a 3-class model from the pretrained backbone and
    # check the output shape on a dummy batch.
    num_classes = 3  # set to the number of classes you need
    model = MobileNetV3_Large.from_pretrained(num_classes=num_classes)
    batch = torch.randn(4, 3, 224, 224)
    y = model(batch)
    print(y.shape)  # expected: torch.Size([4, 3])
BIN
class3/test/mobilenetv3_binary.pth
Normal file
BIN
class3/test/mobilenetv3_binary.pth
Normal file
Binary file not shown.
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from mobilenetv3 import MobileNetV3_Large
|
||||
from class3.test.mobilenetv3 import MobileNetV3_Large
|
||||
|
||||
def train(args):
|
||||
# 数据增强 & 预处理
|
||||
@ -29,8 +29,8 @@ def train(args):
|
||||
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||
|
||||
# 模型
|
||||
model = MobileNetV3_Large(num_classes=2)
|
||||
# 模型 更改分类类别数量
|
||||
model = MobileNetV3_Large(num_classes=3)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user