feat: initial commit
This commit is contained in:
3
.idea/.gitignore
generated
vendored
Normal file
3
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
8
.idea/class.iml
generated
Normal file
8
.idea/class.iml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/class.iml" filepath="$PROJECT_DIR$/.idea/class.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
BIN
__pycache__/mobilenetv3.cpython-39.pyc
Normal file
BIN
__pycache__/mobilenetv3.cpython-39.pyc
Normal file
Binary file not shown.
43
classify.py
Normal file
43
classify.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
from mobilenetv3 import MobileNetV3_Large # 你的模型文件名
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# 1. 定义推理设备
|
||||||
|
# -------------------------------
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# 2. 创建模型并加载权重
|
||||||
|
# -------------------------------
|
||||||
|
model = MobileNetV3_Large(num_classes=2)
|
||||||
|
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
|
||||||
|
model.to(device)
|
||||||
|
model.eval() # 推理模式
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# 3. 图片预处理
|
||||||
|
# -------------------------------
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
def predict(image_path):
|
||||||
|
img = Image.open(image_path).convert('RGB')
|
||||||
|
x = transform(img).unsqueeze(0).to(device) # 增加 batch 维度
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(x)
|
||||||
|
probs = torch.softmax(outputs, dim=1)
|
||||||
|
pred_class = torch.argmax(probs, dim=1).item()
|
||||||
|
return pred_class, probs.cpu().numpy()
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# 4. 测试推理
|
||||||
|
# -------------------------------
|
||||||
|
image_path = "2.png"
|
||||||
|
cls, prob = predict(image_path)
|
||||||
|
print(f"Predicted class: {cls}, Probabilities: {prob}")
|
||||||
67
copy_jpg_images.py
Normal file
67
copy_jpg_images.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def copy_png_images(source_folder, destination_folder):
|
||||||
|
"""
|
||||||
|
从源文件夹递归搜索所有 .png 文件并复制到目标文件夹。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
source_folder (str): 源文件夹路径
|
||||||
|
destination_folder (str): 目标文件夹路径
|
||||||
|
"""
|
||||||
|
# 将字符串路径转换为 Path 对象,更方便操作
|
||||||
|
src_path = Path(source_folder)
|
||||||
|
dest_path = Path(destination_folder)
|
||||||
|
|
||||||
|
# 检查源文件夹是否存在
|
||||||
|
if not src_path.exists():
|
||||||
|
print(f"错误:源文件夹 '{source_folder}' 不存在。")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not src_path.is_dir():
|
||||||
|
print(f"错误:'{source_folder}' 不是一个有效的文件夹。")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果目标文件夹不存在,则创建它
|
||||||
|
dest_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 用于统计复制的文件数量
|
||||||
|
copied_count = 0
|
||||||
|
|
||||||
|
# 使用 rglob 递归搜索所有 .png 文件(不区分大小写)
|
||||||
|
# rglob('*.[pP][nN][gG]') 可以匹配 .png, .PNG, .Png 等
|
||||||
|
png_files = src_path.rglob('*.[jJ][pP][gG]')
|
||||||
|
|
||||||
|
for png_file in png_files:
|
||||||
|
try:
|
||||||
|
# 计算目标文件的完整路径
|
||||||
|
# 保持源文件夹的相对结构,但只保留文件名(可选:如果想保持结构,去掉 .name)
|
||||||
|
# 这里我们只复制文件名到目标文件夹,避免路径过长或结构复杂
|
||||||
|
dest_file = dest_path / png_file.name
|
||||||
|
|
||||||
|
# 处理重名文件:在文件名后添加序号
|
||||||
|
counter = 1
|
||||||
|
original_dest_file = dest_file
|
||||||
|
while dest_file.exists():
|
||||||
|
dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}"
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
# 执行复制
|
||||||
|
shutil.copy2(png_file, dest_file) # copy2 会保留文件的元数据(如修改时间)
|
||||||
|
print(f"已复制: {png_file} -> {dest_file}")
|
||||||
|
copied_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"复制文件时出错: {png_file} - {e}")
|
||||||
|
|
||||||
|
print(f"\n复制完成!共复制了 {copied_count} 个 PNG 文件到 '{destination_folder}'。")
|
||||||
|
|
||||||
|
# ------------------ 主程序 ------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# ====== 请在这里修改源文件夹和目标文件夹的路径 ======
|
||||||
|
source_folder = r"/home/hx/桌面/git/class/data/folder_end" # 替换为你的源文件夹路径
|
||||||
|
destination_folder = r"/home/hx/桌面/git/class/data/class0" # 替换为你的目标文件夹路径
|
||||||
|
# ====================================================
|
||||||
|
|
||||||
|
copy_png_images(source_folder, destination_folder)
|
||||||
24
data_load.py
Normal file
24
data_load.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 图像预处理
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # 调整为 MobileNet 输入尺寸
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]) # ImageNet 预训练均值方差
|
||||||
|
])
|
||||||
|
|
||||||
|
# 数据集路径
|
||||||
|
train_dir = 'dataset_root/train'
|
||||||
|
val_dir = 'dataset_root/val'
|
||||||
|
|
||||||
|
# 加载数据集
|
||||||
|
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
|
||||||
|
val_dataset = datasets.ImageFolder(val_dir, transform=transform)
|
||||||
|
|
||||||
|
# DataLoader
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
print(train_dataset.class_to_idx) # {'class0': 0, 'class1': 1}
|
||||||
163
mobilenetv3.py
Normal file
163
mobilenetv3.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.init as init
|
||||||
|
|
||||||
|
|
||||||
|
# h-swish 激活函数
|
||||||
|
class hswish(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x * F.relu6(x + 3, inplace=True) / 6
|
||||||
|
|
||||||
|
|
||||||
|
# SE 模块
|
||||||
|
class SE_Module(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=4):
|
||||||
|
super(SE_Module, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(channel, channel // reduction, bias=False),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(channel // reduction, channel, bias=False),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, _, _ = x.size()
|
||||||
|
y = self.avg_pool(x).view(b, c)
|
||||||
|
y = self.fc(y).view(b, c, 1, 1)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Bneck 模块(修改:动态生成 SE_Module)
|
||||||
|
class Bneck(nn.Module):
|
||||||
|
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
|
||||||
|
super(Bneck, self).__init__()
|
||||||
|
self.stride = s
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(expand_size)
|
||||||
|
self.nolinear1 = nolinear
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s,
|
||||||
|
padding=kernel_size // 2, groups=expand_size, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(expand_size)
|
||||||
|
self.nolinear2 = nolinear
|
||||||
|
|
||||||
|
# 动态生成 SE 模块
|
||||||
|
self.se = SE_Module(expand_size) if use_se else None
|
||||||
|
|
||||||
|
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_size)
|
||||||
|
|
||||||
|
self.shortcut = (self.stride == 1 and in_size == out_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.nolinear1(self.bn1(self.conv1(x)))
|
||||||
|
out = self.nolinear2(self.bn2(self.conv2(out)))
|
||||||
|
if self.se is not None:
|
||||||
|
out = self.se(out)
|
||||||
|
out = self.bn3(self.conv3(out))
|
||||||
|
if self.shortcut:
|
||||||
|
return x + out
|
||||||
|
else:
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3_Large(nn.Module):
|
||||||
|
def __init__(self, num_classes=1000):
|
||||||
|
super(MobileNetV3_Large, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.init_params()
|
||||||
|
|
||||||
|
# stem
|
||||||
|
self.top = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(16),
|
||||||
|
hswish()
|
||||||
|
)
|
||||||
|
|
||||||
|
# bottlenecks(修改:use_se 参数替代 semodule)
|
||||||
|
self.bneck = nn.Sequential(
|
||||||
|
Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
|
||||||
|
Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
|
||||||
|
Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),
|
||||||
|
|
||||||
|
Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
|
||||||
|
Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
|
||||||
|
Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
|
||||||
|
|
||||||
|
Bneck(3, 40, 240, 80, hswish(), False, 2),
|
||||||
|
Bneck(3, 80, 200, 80, hswish(), False, 1),
|
||||||
|
Bneck(3, 80, 184, 80, hswish(), False, 1),
|
||||||
|
Bneck(3, 80, 184, 80, hswish(), False, 1),
|
||||||
|
Bneck(3, 80, 480, 112, hswish(), True, 1),
|
||||||
|
Bneck(3, 112, 672, 112, hswish(), True, 1),
|
||||||
|
|
||||||
|
Bneck(5, 112, 672, 160, hswish(), True, 1),
|
||||||
|
Bneck(5, 160, 672, 160, hswish(), True, 2),
|
||||||
|
Bneck(5, 160, 960, 160, hswish(), True, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# final conv
|
||||||
|
self.bottom = nn.Sequential(
|
||||||
|
nn.Conv2d(160, 960, kernel_size=1, bias=False),
|
||||||
|
nn.BatchNorm2d(960),
|
||||||
|
hswish()
|
||||||
|
)
|
||||||
|
|
||||||
|
# classifier
|
||||||
|
self.last = nn.Sequential(
|
||||||
|
nn.Linear(960, 1280),
|
||||||
|
nn.BatchNorm1d(1280),
|
||||||
|
hswish()
|
||||||
|
)
|
||||||
|
self.linear = nn.Linear(1280, num_classes)
|
||||||
|
|
||||||
|
def init_params(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias is not None:
|
||||||
|
init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.top(x)
|
||||||
|
out = self.bneck(out)
|
||||||
|
out = self.bottom(out)
|
||||||
|
out = F.avg_pool2d(out, out.size(2)) # 自适应池化
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
out = self.last(out)
|
||||||
|
out = self.linear(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained_for_binary():
|
||||||
|
"""从 torchvision 自动下载 ImageNet 预训练,并改成二分类"""
|
||||||
|
print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
|
||||||
|
from torchvision.models import mobilenet_v3_large
|
||||||
|
|
||||||
|
official_model = mobilenet_v3_large(pretrained=True)
|
||||||
|
model = MobileNetV3_Large(num_classes=1000)
|
||||||
|
model.load_state_dict(official_model.state_dict(), strict=False)
|
||||||
|
|
||||||
|
# 替换分类头
|
||||||
|
in_features = model.linear.in_features
|
||||||
|
model.linear = nn.Linear(in_features, 2)
|
||||||
|
print(f"Replaced classifier head: {in_features} -> 2")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试二分类模型
|
||||||
|
model = MobileNetV3_Large.from_pretrained_for_binary()
|
||||||
|
x = torch.randn(4, 3, 224, 224)
|
||||||
|
y = model(x)
|
||||||
|
print(y.shape) # [4, 2]
|
||||||
BIN
mobilenetv3_binary.pth
Normal file
BIN
mobilenetv3_binary.pth
Normal file
Binary file not shown.
89
train.py
Normal file
89
train.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from mobilenetv3 import MobileNetV3_Large
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
# 数据增强 & 预处理
|
||||||
|
transform_train = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
transform_val = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# 加载数据集
|
||||||
|
train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train)
|
||||||
|
val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
# 模型
|
||||||
|
model = MobileNetV3_Large(num_classes=2)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# 损失函数 & 优化器
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
|
# 训练
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
model.train()
|
||||||
|
running_loss = 0.0
|
||||||
|
correct, total = 0, 0
|
||||||
|
|
||||||
|
for images, labels in train_loader:
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(images)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
_, predicted = outputs.max(1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += predicted.eq(labels).sum().item()
|
||||||
|
|
||||||
|
train_acc = 100. * correct / total
|
||||||
|
print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%")
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
model.eval()
|
||||||
|
correct, total = 0, 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for images, labels in val_loader:
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
outputs = model(images)
|
||||||
|
_, predicted = outputs.max(1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += predicted.eq(labels).sum().item()
|
||||||
|
|
||||||
|
val_acc = 100. * correct / total
|
||||||
|
print(f"Validation Acc: {val_acc:.2f}%")
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
torch.save(model.state_dict(), "mobilenetv3_binary.pth")
|
||||||
|
print("训练完成,模型已保存到 mobilenetv3_binary.pth")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data", type=str, default="./data", help="数据集路径")
|
||||||
|
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=32, help="批大小")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
|
||||||
|
args = parser.parse_args()
|
||||||
|
train(args)
|
||||||
Reference in New Issue
Block a user