commit c1b5481a2957a96990b26a57da3cd4edcf466509
Author: 琉璃月光 <15630071+llyg777@user.noreply.gitee.com>
Date: Wed Aug 13 18:03:52 2025 +0800
feat: initial commit
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..359bb53
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
diff --git a/.idea/class.iml b/.idea/class.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/class.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..9de2865
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..247835b
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..6c0b863
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/1.png b/1.png
new file mode 100644
index 0000000..8df7feb
Binary files /dev/null and b/1.png differ
diff --git a/2.png b/2.png
new file mode 100644
index 0000000..0512d78
Binary files /dev/null and b/2.png differ
diff --git a/__pycache__/mobilenetv3.cpython-39.pyc b/__pycache__/mobilenetv3.cpython-39.pyc
new file mode 100644
index 0000000..515e2fd
Binary files /dev/null and b/__pycache__/mobilenetv3.cpython-39.pyc differ
diff --git a/classify.py b/classify.py
new file mode 100644
index 0000000..83d6bbd
--- /dev/null
+++ b/classify.py
@@ -0,0 +1,43 @@
+import torch
+from torchvision import transforms
+from PIL import Image
+from mobilenetv3 import MobileNetV3_Large # the local model definition module
+
+# -------------------------------
+# 1. Select the inference device
+# -------------------------------
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# -------------------------------
+# 2. Build the model and load the trained weights
+# -------------------------------
+model = MobileNetV3_Large(num_classes=2)
+model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
+model.to(device)
+model.eval() # inference mode (dropout off, batch-norm uses running stats)
+
+# -------------------------------
+# 3. Image preprocessing
+# -------------------------------
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                         std=[0.229, 0.224, 0.225])
+])
+
+def predict(image_path): # returns (predicted class index, per-class probabilities as numpy)
+    img = Image.open(image_path).convert('RGB')
+    x = transform(img).unsqueeze(0).to(device) # add the batch dimension
+    with torch.no_grad():
+        outputs = model(x)
+    probs = torch.softmax(outputs, dim=1)
+    pred_class = torch.argmax(probs, dim=1).item()
+    return pred_class, probs.cpu().numpy()
+
+# -------------------------------
+# 4. Run a sample prediction
+# -------------------------------
+image_path = "2.png"
+cls, prob = predict(image_path)
+print(f"Predicted class: {cls}, Probabilities: {prob}")
diff --git a/copy_jpg_images.py b/copy_jpg_images.py
new file mode 100644
index 0000000..1092410
--- /dev/null
+++ b/copy_jpg_images.py
@@ -0,0 +1,67 @@
+import os # NOTE(review): unused import
+import shutil
+from pathlib import Path
+
+def copy_png_images(source_folder, destination_folder):
+    """
+    Recursively copy every .jpg file under the source folder into the
+    destination folder (flattened; name clashes get a numeric suffix).
+    NOTE(review): despite the PNG naming, the glob matches .jpg files.
+    Parameters:
+    source_folder (str): source folder path
+    destination_folder (str): destination folder path
+    """
+    # Convert string paths to Path objects for easier handling
+    src_path = Path(source_folder)
+    dest_path = Path(destination_folder)
+
+    # Guard: the source folder must exist
+    if not src_path.exists():
+        print(f"错误:源文件夹 '{source_folder}' 不存在。")
+        return
+
+    if not src_path.is_dir():
+        print(f"错误:'{source_folder}' 不是一个有效的文件夹。")
+        return
+
+    # Create the destination folder if it does not exist yet
+    dest_path.mkdir(parents=True, exist_ok=True)
+
+    # Number of files copied so far
+    copied_count = 0
+
+    # Recursively match image files, case-insensitively.
+    # NOTE(review): the pattern matches .jpg/.JPG/.Jpg etc. — not .png,
+    png_files = src_path.rglob('*.[jJ][pP][gG]') # consistent with this file's name (copy_jpg_images.py)
+
+    for png_file in png_files:
+        try:
+            # Build the destination path using only the file name, so the
+            # output folder is flat — the source directory structure is
+            # intentionally not preserved.
+            dest_file = dest_path / png_file.name
+
+            # On a name collision, append an increasing numeric suffix
+            counter = 1
+            original_dest_file = dest_file
+            while dest_file.exists():
+                dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}"
+                counter += 1
+
+            # Perform the copy
+            shutil.copy2(png_file, dest_file) # copy2 preserves file metadata (e.g. mtime)
+            print(f"已复制: {png_file} -> {dest_file}")
+            copied_count += 1
+
+        except Exception as e:
+            print(f"复制文件时出错: {png_file} - {e}")
+
+    print(f"\n复制完成!共复制了 {copied_count} 个 PNG 文件到 '{destination_folder}'。") # NOTE(review): message says PNG, but the glob copies .jpg
+
+# ------------------ Main ------------------
+if __name__ == "__main__":
+    # ====== Edit the source and destination folder paths here ======
+    source_folder = r"/home/hx/桌面/git/class/data/folder_end" # replace with your source folder path
+    destination_folder = r"/home/hx/桌面/git/class/data/class0" # replace with your destination folder path
+    # ====================================================
+
+    copy_png_images(source_folder, destination_folder)
\ No newline at end of file
diff --git a/data_load.py b/data_load.py
new file mode 100644
index 0000000..6663191
--- /dev/null
+++ b/data_load.py
@@ -0,0 +1,24 @@
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+
+# Image preprocessing
+transform = transforms.Compose([
+    transforms.Resize((224, 224)), # resize to the MobileNet input size
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                         std=[0.229, 0.224, 0.225]) # ImageNet pretraining mean/std
+])
+
+# Dataset paths (ImageFolder layout: one subfolder per class)
+train_dir = 'dataset_root/train'
+val_dir = 'dataset_root/val'
+
+# Load the datasets
+train_dataset = datasets.ImageFolder(train_dir, transform=transform)
+val_dataset = datasets.ImageFolder(val_dir, transform=transform)
+
+# DataLoaders
+train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4) # NOTE(review): worker processes at import time need a __main__ guard on spawn platforms (Windows/macOS) — confirm target OS
+val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
+
+print(train_dataset.class_to_idx) # e.g. {'class0': 0, 'class1': 1}
diff --git a/mobilenetv3.py b/mobilenetv3.py
new file mode 100644
index 0000000..6fb6574
--- /dev/null
+++ b/mobilenetv3.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+
+# h-swish activation: x * ReLU6(x + 3) / 6
+class hswish(nn.Module):
+    def forward(self, x):
+        return x * F.relu6(x + 3, inplace=True) / 6
+
+
+# Squeeze-and-Excitation module: global pooling + 2-layer MLP producing per-channel scales
+class SE_Module(nn.Module):
+    def __init__(self, channel, reduction=4):
+        super(SE_Module, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c) # squeeze: (b, c)
+        y = self.fc(y).view(b, c, 1, 1) # excite: per-channel scale in (0, 1)
+        return x * y.expand_as(x)
+
+
+# Bneck block: 1x1 expand -> depthwise conv -> optional SE -> 1x1 project (SE built on demand)
+class Bneck(nn.Module):
+    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
+        super(Bneck, self).__init__()
+        self.stride = s
+
+        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False) # 1x1 pointwise expansion
+        self.bn1 = nn.BatchNorm2d(expand_size)
+        self.nolinear1 = nolinear
+
+        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s,
+                               padding=kernel_size // 2, groups=expand_size, bias=False) # depthwise (groups == channels)
+        self.bn2 = nn.BatchNorm2d(expand_size)
+        self.nolinear2 = nolinear
+
+        # Build the SE block only when requested
+        self.se = SE_Module(expand_size) if use_se else None
+
+        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False) # 1x1 linear projection
+        self.bn3 = nn.BatchNorm2d(out_size)
+
+        self.shortcut = (self.stride == 1 and in_size == out_size) # residual only when shapes match
+
+    def forward(self, x):
+        out = self.nolinear1(self.bn1(self.conv1(x)))
+        out = self.nolinear2(self.bn2(self.conv2(out)))
+        if self.se is not None:
+            out = self.se(out)
+        out = self.bn3(self.conv3(out)) # no activation after the projection
+        if self.shortcut:
+            return x + out
+        else:
+            return out
+
+
+class MobileNetV3_Large(nn.Module):
+    def __init__(self, num_classes=1000):
+        super(MobileNetV3_Large, self).__init__()
+        self.num_classes = num_classes
+        self.init_params() # NOTE(review): runs before any submodule exists, so the init loop iterates nothing — likely meant to be called at the END of __init__
+
+        # stem: 3x3 stride-2 conv
+        self.top = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(16),
+            hswish()
+        )
+
+        # bottleneck stack (use_se flag instead of passing an SE module)
+        self.bneck = nn.Sequential(
+            Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
+            Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
+            Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),
+
+            Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
+            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
+            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
+
+            Bneck(3, 40, 240, 80, hswish(), False, 2),
+            Bneck(3, 80, 200, 80, hswish(), False, 1),
+            Bneck(3, 80, 184, 80, hswish(), False, 1),
+            Bneck(3, 80, 184, 80, hswish(), False, 1),
+            Bneck(3, 80, 480, 112, hswish(), True, 1),
+            Bneck(3, 112, 672, 112, hswish(), True, 1),
+
+            Bneck(5, 112, 672, 160, hswish(), True, 1), # NOTE(review): the reference MobileNetV3-Large puts stride 2 on this first 160-ch block, not the next — confirm intended
+            Bneck(5, 160, 672, 160, hswish(), True, 2),
+            Bneck(5, 160, 960, 160, hswish(), True, 1),
+        )
+
+        # final 1x1 conv before pooling
+        self.bottom = nn.Sequential(
+            nn.Conv2d(160, 960, kernel_size=1, bias=False),
+            nn.BatchNorm2d(960),
+            hswish()
+        )
+
+        # classifier head: 960 -> 1280 -> num_classes
+        self.last = nn.Sequential(
+            nn.Linear(960, 1280),
+            nn.BatchNorm1d(1280),
+            hswish()
+        )
+        self.linear = nn.Linear(1280, num_classes)
+
+    def init_params(self): # Kaiming init for convs, constants for norms, small-normal for linears
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        out = self.top(x)
+        out = self.bneck(out)
+        out = self.bottom(out)
+        out = F.avg_pool2d(out, out.size(2)) # global pool (kernel = feature-map height; assumes a square map — confirm)
+        out = out.view(out.size(0), -1)
+        out = self.last(out)
+        out = self.linear(out)
+        return out
+
+    @staticmethod
+    def from_pretrained_for_binary():
+        """Download torchvision's ImageNet MobileNetV3-Large weights and switch the head to 2 classes."""
+        print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
+        from torchvision.models import mobilenet_v3_large
+
+        official_model = mobilenet_v3_large(pretrained=True) # NOTE(review): 'pretrained=' is deprecated in newer torchvision (use weights=) — confirm installed version
+        model = MobileNetV3_Large(num_classes=1000)
+        model.load_state_dict(official_model.state_dict(), strict=False) # NOTE(review): torchvision parameter names differ from this custom class, so strict=False may silently load few or no tensors — verify key overlap
+
+        # Replace the classifier head for binary output
+        in_features = model.linear.in_features
+        model.linear = nn.Linear(in_features, 2)
+        print(f"Replaced classifier head: {in_features} -> 2")
+        return model
+
+
+if __name__ == "__main__":
+    # Smoke-test the binary-classification model on a random batch
+    model = MobileNetV3_Large.from_pretrained_for_binary()
+    x = torch.randn(4, 3, 224, 224)
+    y = model(x)
+    print(y.shape) # [4, 2]
diff --git a/mobilenetv3_binary.pth b/mobilenetv3_binary.pth
new file mode 100644
index 0000000..56711ee
Binary files /dev/null and b/mobilenetv3_binary.pth differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..dcf2aca
--- /dev/null
+++ b/train.py
@@ -0,0 +1,89 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+from mobilenetv3 import MobileNetV3_Large
+
+def train(args): # train a 2-class MobileNetV3-Large on the ImageFolder dataset under args.data
+    # Data augmentation & preprocessing (flip only on the training split)
+    transform_train = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+    transform_val = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    # Datasets (expects <args.data>/train and <args.data>/val, one subfolder per class)
+    train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train)
+    val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val)
+
+    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
+
+    # Model
+    model = MobileNetV3_Large(num_classes=2)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
+    # Loss & optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+
+    # Training loop
+    for epoch in range(args.epochs):
+        model.train()
+        running_loss = 0.0
+        correct, total = 0, 0
+
+        for images, labels in train_loader:
+            images, labels = images.to(device), labels.to(device)
+
+            optimizer.zero_grad()
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            _, predicted = outputs.max(1) # class index with the highest logit
+            total += labels.size(0)
+            correct += predicted.eq(labels).sum().item()
+
+        train_acc = 100. * correct / total
+        print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%")
+
+        # Validation pass (no gradients)
+        model.eval()
+        correct, total = 0, 0
+        with torch.no_grad():
+            for images, labels in val_loader:
+                images, labels = images.to(device), labels.to(device)
+                outputs = model(images)
+                _, predicted = outputs.max(1)
+                total += labels.size(0)
+                correct += predicted.eq(labels).sum().item()
+
+        val_acc = 100. * correct / total
+        print(f"Validation Acc: {val_acc:.2f}%")
+
+    # Save the weights from the final epoch (no best-checkpoint tracking)
+    torch.save(model.state_dict(), "mobilenetv3_binary.pth")
+    print("训练完成,模型已保存到 mobilenetv3_binary.pth")
+
+if __name__ == "__main__": # CLI entry point; --data must contain train/ and val/ subfolders
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data", type=str, default="./data", help="数据集路径")
+    parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
+    parser.add_argument("--batch_size", type=int, default=32, help="批大小")
+    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
+    args = parser.parse_args()
+    train(args)