Files
zjsh_classification/class2/train_best_pt.py
2025-08-14 18:27:52 +08:00

144 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from class2.test.mobilenetv3 import MobileNetV3_Large
import os
# ImageNet normalization statistics, shared by both pipelines.
_NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])


def _build_loaders(args):
    """Build (train_loader, val_loader) from an ImageFolder layout.

    Expects ``{args.data}/train`` and ``{args.data}/val`` directories with
    one sub-directory per class.
    """
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _NORMALIZE,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _NORMALIZE,
    ])
    train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train)
    val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)
    return train_loader, val_loader


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch; return (avg_loss, accuracy_percent)."""
    model.train()
    running_loss = 0.0
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100. * correct / total


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return accuracy (percent) of *model* over *loader*."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return 100. * correct / total


def train(args):
    """Train a 2-class MobileNetV3-Large with early stopping.

    Saves the *last* checkpoint that matched-or-beat the best validation
    accuracy to ``checkpoints/mobilenetv3_last_best.pth``. Early stopping
    triggers after ``args.patience`` consecutive epochs whose validation
    accuracy is strictly below the historical best (ties reset the counter).
    """
    # Output directory for checkpoints
    os.makedirs("checkpoints", exist_ok=True)

    # Data augmentation & preprocessing, then dataset loading
    train_loader, val_loader = _build_loaders(args)

    # Model
    model = MobileNetV3_Large(num_classes=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Loss & optimizer, with a step decay on the learning rate
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # ========== Early stopping + "last best" checkpoint bookkeeping ==========
    best_val_acc = 0.0
    best_epoch = 0
    patience = args.patience  # default 5
    wait = 0  # consecutive epochs without matching/beating the best

    print(f"开始训练,共 {args.epochs} 轮,早停耐心值: {patience}")
    for epoch in range(args.epochs):
        # --- training phase ---
        avg_loss, train_acc = _train_one_epoch(model, train_loader, criterion,
                                               optimizer, device)
        scheduler.step()

        # --- validation phase ---
        val_acc = _evaluate(model, val_loader, device)

        print(f"[Epoch {epoch+1:2d}/{args.epochs}] "
              f"Loss: {avg_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Acc: {val_acc:.2f}% | "
              f"LR: {scheduler.get_last_lr()[0]:.6f}")

        # ========== Save the LAST checkpoint at the peak + improved early stop ==========
        if val_acc >= best_val_acc:  # matching the best still counts as "holding the peak"
            if val_acc > best_val_acc:
                print(f"新的最高验证准确率: {val_acc:.2f}% (原: {best_val_acc:.2f}%)")
            else:
                print(f"验证准确率持平历史最佳: {val_acc:.2f}%")
            best_val_acc = val_acc
            best_epoch = epoch + 1
            wait = 0  # reset the patience counter
            # Full checkpoint so training can be resumed from the saved state.
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'best_val_acc': best_val_acc,
                'args': args
            }
            torch.save(checkpoint, "checkpoints/mobilenetv3_last_best.pth")
            print(f"✅ 已保存当前最佳权重epoch {best_epoch}")
        else:
            # Only a strict drop below the best increments the wait counter.
            wait += 1
            print(f"验证准确率下降 ({val_acc:.2f}% < {best_val_acc:.2f}%),已连续 {wait}/{patience} 次未超越")
            if wait >= patience:
                print(f"早停触发!连续 {patience} 次验证准确率未达到或超越历史最佳。")
                break

    # Training finished
    print("\n✅ 训练完成!")
    print(f"最终最佳验证准确率: {best_val_acc:.2f}% (第 {best_epoch} 轮)")
    print("最后一个最高权重已保存至: 'checkpoints/mobilenetv3_last_best.pth'")
if __name__ == "__main__":
    # Command-line entry point: collect hyper-parameters and launch training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data", type=str,
                     default="/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/classdata",
                     help="数据集路径")
    cli.add_argument("--epochs", type=int, default=100, help="训练轮数")
    cli.add_argument("--batch_size", type=int, default=32, help="批大小")
    cli.add_argument("--lr", type=float, default=1e-3, help="学习率")
    cli.add_argument("--patience", type=int, default=5,
                     help="早停耐心值(连续多少次未达到或超越历史最佳才停止)")
    train(cli.parse_args())