feat: initial commit

This commit is contained in:
琉璃月光
2025-08-13 18:03:52 +08:00
commit c1b5481a29
15 changed files with 424 additions and 0 deletions

3
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,3 @@
# 默认忽略的文件
/shelf/
/workspace.xml

8
.idea/class.iml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/class.iml" filepath="$PROJECT_DIR$/.idea/class.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

BIN
1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 MiB

BIN
2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.7 MiB

Binary file not shown.

43
classify.py Normal file
View File

@ -0,0 +1,43 @@
import torch
from torchvision import transforms
from PIL import Image
from mobilenetv3 import MobileNetV3_Large # 你的模型文件名
# -------------------------------
# 1. 定义推理设备
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------------------
# 2. 创建模型并加载权重
# -------------------------------
model = MobileNetV3_Large(num_classes=2)
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
model.to(device)
model.eval() # 推理模式
# -------------------------------
# 3. 图片预处理
# -------------------------------
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def predict(image_path):
img = Image.open(image_path).convert('RGB')
x = transform(img).unsqueeze(0).to(device) # 增加 batch 维度
with torch.no_grad():
outputs = model(x)
probs = torch.softmax(outputs, dim=1)
pred_class = torch.argmax(probs, dim=1).item()
return pred_class, probs.cpu().numpy()
# -------------------------------
# 4. 测试推理
# -------------------------------
image_path = "2.png"
cls, prob = predict(image_path)
print(f"Predicted class: {cls}, Probabilities: {prob}")

67
copy_jpg_images.py Normal file
View File

@ -0,0 +1,67 @@
import os
import shutil
from pathlib import Path
def copy_png_images(source_folder, destination_folder):
"""
从源文件夹递归搜索所有 .png 文件并复制到目标文件夹。
参数:
source_folder (str): 源文件夹路径
destination_folder (str): 目标文件夹路径
"""
# 将字符串路径转换为 Path 对象,更方便操作
src_path = Path(source_folder)
dest_path = Path(destination_folder)
# 检查源文件夹是否存在
if not src_path.exists():
print(f"错误:源文件夹 '{source_folder}' 不存在。")
return
if not src_path.is_dir():
print(f"错误:'{source_folder}' 不是一个有效的文件夹。")
return
# 如果目标文件夹不存在,则创建它
dest_path.mkdir(parents=True, exist_ok=True)
# 用于统计复制的文件数量
copied_count = 0
# 使用 rglob 递归搜索所有 .png 文件(不区分大小写)
# rglob('*.[pP][nN][gG]') 可以匹配 .png, .PNG, .Png 等
png_files = src_path.rglob('*.[jJ][pP][gG]')
for png_file in png_files:
try:
# 计算目标文件的完整路径
# 保持源文件夹的相对结构,但只保留文件名(可选:如果想保持结构,去掉 .name
# 这里我们只复制文件名到目标文件夹,避免路径过长或结构复杂
dest_file = dest_path / png_file.name
# 处理重名文件:在文件名后添加序号
counter = 1
original_dest_file = dest_file
while dest_file.exists():
dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}"
counter += 1
# 执行复制
shutil.copy2(png_file, dest_file) # copy2 会保留文件的元数据(如修改时间)
print(f"已复制: {png_file} -> {dest_file}")
copied_count += 1
except Exception as e:
print(f"复制文件时出错: {png_file} - {e}")
print(f"\n复制完成!共复制了 {copied_count} 个 PNG 文件到 '{destination_folder}'")
# ------------------ 主程序 ------------------
if __name__ == "__main__":
# ====== 请在这里修改源文件夹和目标文件夹的路径 ======
source_folder = r"/home/hx/桌面/git/class/data/folder_end" # 替换为你的源文件夹路径
destination_folder = r"/home/hx/桌面/git/class/data/class0" # 替换为你的目标文件夹路径
# ====================================================
copy_png_images(source_folder, destination_folder)

24
data_load.py Normal file
View File

@ -0,0 +1,24 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整为 MobileNet 输入尺寸
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) # ImageNet 预训练均值方差
])
# 数据集路径
train_dir = 'dataset_root/train'
val_dir = 'dataset_root/val'
# 加载数据集
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)
# DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
print(train_dataset.class_to_idx) # {'class0': 0, 'class1': 1}

163
mobilenetv3.py Normal file
View File

@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# h-swish 激活函数
class hswish(nn.Module):
def forward(self, x):
return x * F.relu6(x + 3, inplace=True) / 6
# SE 模块
class SE_Module(nn.Module):
def __init__(self, channel, reduction=4):
super(SE_Module, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
# Bneck 模块(修改:动态生成 SE_Module
class Bneck(nn.Module):
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
super(Bneck, self).__init__()
self.stride = s
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(expand_size)
self.nolinear1 = nolinear
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s,
padding=kernel_size // 2, groups=expand_size, bias=False)
self.bn2 = nn.BatchNorm2d(expand_size)
self.nolinear2 = nolinear
# 动态生成 SE 模块
self.se = SE_Module(expand_size) if use_se else None
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_size)
self.shortcut = (self.stride == 1 and in_size == out_size)
def forward(self, x):
out = self.nolinear1(self.bn1(self.conv1(x)))
out = self.nolinear2(self.bn2(self.conv2(out)))
if self.se is not None:
out = self.se(out)
out = self.bn3(self.conv3(out))
if self.shortcut:
return x + out
else:
return out
class MobileNetV3_Large(nn.Module):
def __init__(self, num_classes=1000):
super(MobileNetV3_Large, self).__init__()
self.num_classes = num_classes
self.init_params()
# stem
self.top = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(16),
hswish()
)
# bottlenecks修改use_se 参数替代 semodule
self.bneck = nn.Sequential(
Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),
Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
Bneck(3, 40, 240, 80, hswish(), False, 2),
Bneck(3, 80, 200, 80, hswish(), False, 1),
Bneck(3, 80, 184, 80, hswish(), False, 1),
Bneck(3, 80, 184, 80, hswish(), False, 1),
Bneck(3, 80, 480, 112, hswish(), True, 1),
Bneck(3, 112, 672, 112, hswish(), True, 1),
Bneck(5, 112, 672, 160, hswish(), True, 1),
Bneck(5, 160, 672, 160, hswish(), True, 2),
Bneck(5, 160, 960, 160, hswish(), True, 1),
)
# final conv
self.bottom = nn.Sequential(
nn.Conv2d(160, 960, kernel_size=1, bias=False),
nn.BatchNorm2d(960),
hswish()
)
# classifier
self.last = nn.Sequential(
nn.Linear(960, 1280),
nn.BatchNorm1d(1280),
hswish()
)
self.linear = nn.Linear(1280, num_classes)
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
out = self.top(x)
out = self.bneck(out)
out = self.bottom(out)
out = F.avg_pool2d(out, out.size(2)) # 自适应池化
out = out.view(out.size(0), -1)
out = self.last(out)
out = self.linear(out)
return out
@staticmethod
def from_pretrained_for_binary():
"""从 torchvision 自动下载 ImageNet 预训练,并改成二分类"""
print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
from torchvision.models import mobilenet_v3_large
official_model = mobilenet_v3_large(pretrained=True)
model = MobileNetV3_Large(num_classes=1000)
model.load_state_dict(official_model.state_dict(), strict=False)
# 替换分类头
in_features = model.linear.in_features
model.linear = nn.Linear(in_features, 2)
print(f"Replaced classifier head: {in_features} -> 2")
return model
if __name__ == "__main__":
# 测试二分类模型
model = MobileNetV3_Large.from_pretrained_for_binary()
x = torch.randn(4, 3, 224, 224)
y = model(x)
print(y.shape) # [4, 2]

BIN
mobilenetv3_binary.pth Normal file

Binary file not shown.

89
train.py Normal file
View File

@ -0,0 +1,89 @@
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mobilenetv3 import MobileNetV3_Large
def train(args):
# 数据增强 & 预处理
transform_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
transform_val = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train)
val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
# 模型
model = MobileNetV3_Large(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 损失函数 & 优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# 训练
for epoch in range(args.epochs):
model.train()
running_loss = 0.0
correct, total = 0, 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
train_acc = 100. * correct / total
print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%")
# 验证
model.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
val_acc = 100. * correct / total
print(f"Validation Acc: {val_acc:.2f}%")
# 保存模型
torch.save(model.state_dict(), "mobilenetv3_binary.pth")
print("训练完成,模型已保存到 mobilenetv3_binary.pth")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data", type=str, default="./data", help="数据集路径")
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="批大小")
parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
args = parser.parse_args()
train(args)