"""Single-image inference script for the binary MobileNetV3-Large classifier."""
import torch
from torchvision import transforms
from PIL import Image

from mobilenetv3 import MobileNetV3_Large  # local model definition

# -------------------------------
# 1. Inference device (GPU when available, otherwise CPU)
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# 2. Build the network and restore the trained binary-classification weights
# -------------------------------
model = MobileNetV3_Large(num_classes=2)
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
model.to(device)
model.eval()  # inference mode: BatchNorm uses running stats, no dropout

# -------------------------------
# 3. Preprocessing: resize to the network input size, convert to tensor and
#    normalize with the ImageNet statistics used during training
# -------------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict(image_path):
    """Classify a single image file.

    Args:
        image_path: path of the image on disk.

    Returns:
        Tuple of (predicted class index, numpy array of class probabilities).
    """
    image = Image.open(image_path).convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        predicted = torch.argmax(probabilities, dim=1).item()
    return predicted, probabilities.cpu().numpy()


# -------------------------------
# 4. Smoke test on a sample image
# -------------------------------
image_path = "2.png"
cls, prob = predict(image_path)
print(f"Predicted class: {cls}, Probabilities: {prob}")
import os
import shutil
from pathlib import Path

def copy_png_images(source_folder, destination_folder):
    """
    Recursively collect every .jpg file under *source_folder* and copy it,
    flattened, into *destination_folder* (name clashes get a numeric suffix).

    NOTE(review): despite the function name, the glob below matches *.jpg
    (case-insensitive), not *.png. The name is kept for backward
    compatibility with existing callers; docs and messages now describe the
    actual behavior.

    Args:
        source_folder (str): directory to search.
        destination_folder (str): directory receiving the copies
            (created if missing).
    """
    # Path objects make the traversal and joins below much cleaner.
    src_path = Path(source_folder)
    dest_path = Path(destination_folder)

    # Validate the source before touching the destination.
    if not src_path.exists():
        print(f"错误:源文件夹 '{source_folder}' 不存在。")
        return

    if not src_path.is_dir():
        print(f"错误:'{source_folder}' 不是一个有效的文件夹。")
        return

    # Create the destination folder (and parents) if it does not exist yet.
    dest_path.mkdir(parents=True, exist_ok=True)

    # Number of files successfully copied.
    copied_count = 0

    # The character classes make the match case-insensitive: .jpg, .JPG, .Jpg...
    jpg_files = src_path.rglob('*.[jJ][pP][gG]')

    for jpg_file in jpg_files:
        try:
            # Flatten the tree: only the file name is kept in the destination.
            dest_file = dest_path / jpg_file.name

            # Resolve name clashes by appending an increasing counter.
            counter = 1
            original_dest_file = dest_file
            while dest_file.exists():
                dest_file = original_dest_file.parent / f"{original_dest_file.stem}_{counter}{original_dest_file.suffix}"
                counter += 1

            # copy2 preserves file metadata (e.g. modification time).
            shutil.copy2(jpg_file, dest_file)
            print(f"已复制: {jpg_file} -> {dest_file}")
            copied_count += 1

        except Exception as e:
            # Best effort: report the failure and keep copying the rest.
            print(f"复制文件时出错: {jpg_file} - {e}")

    # BUGFIX: the summary used to say "PNG" although .jpg files are copied.
    print(f"\n复制完成!共复制了 {copied_count} 个 JPG 文件到 '{destination_folder}'。")

# ------------------ main program ------------------
if __name__ == "__main__":
    # ====== 请在这里修改源文件夹和目标文件夹的路径 ======
    source_folder = r"/home/hx/桌面/git/class/data/folder_end"
    destination_folder = r"/home/hx/桌面/git/class/data/class0"
    # ====================================================

    copy_png_images(source_folder, destination_folder)
"""Dataset and DataLoader setup for the train/val splits."""
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ImageNet statistics matching the pretrained MobileNet backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: MobileNet expects 224x224 normalized inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Expected layout: dataset_root/{train,val}/<class_name>/<image files>
train_dir = 'dataset_root/train'
val_dir = 'dataset_root/val'

# ImageFolder derives each image's label from its parent directory name.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# Shuffle only the training split; validation order is irrelevant.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(train_dataset.class_to_idx)  # e.g. {'class0': 0, 'class1': 1}
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class hswish(nn.Module):
    """Hard-swish activation from the MobileNetV3 paper: x * relu6(x + 3) / 6."""

    def forward(self, x):
        # The inplace relu6 only mutates the (x + 3) temporary, never x itself.
        return x * F.relu6(x + 3, inplace=True) / 6


class SE_Module(nn.Module):
    """Squeeze-and-Excitation block: learns a per-channel gate in (0, 1).

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio of the two-layer gating MLP.
    """

    def __init__(self, channel, reduction=4):
        super(SE_Module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)     # squeeze: global spatial average
        y = self.fc(y).view(b, c, 1, 1)     # excite: channel-wise gate
        return x * y.expand_as(x)


class Bneck(nn.Module):
    """MobileNetV3 inverted-residual bottleneck.

    Pipeline: 1x1 expand -> kxk depthwise (stride s) -> optional SE -> 1x1 project.
    An identity shortcut is added only when stride == 1 and in_size == out_size.

    Args:
        kernel_size: depthwise kernel size (3 or 5).
        in_size: input channels.
        expand_size: expanded (hidden) channels.
        out_size: output channels.
        nolinear: activation module shared by the expand and depthwise stages.
        use_se: whether to insert an SE_Module after the depthwise conv.
        s: depthwise stride.
    """

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, use_se, s):
        super(Bneck, self).__init__()
        self.stride = s

        # 1x1 pointwise expansion.
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear

        # kxk depthwise conv (groups == channels) with 'same' padding.
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride=s,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear

        # SE attention is created per block so its size matches expand_size.
        self.se = SE_Module(expand_size) if use_se else None

        # 1x1 linear projection (intentionally no activation afterwards).
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = (self.stride == 1 and in_size == out_size)

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        if self.se is not None:
            out = self.se(out)
        out = self.bn3(self.conv3(out))
        # Identity residual only when spatial size and channel count are preserved.
        if self.shortcut:
            return x + out
        else:
            return out


class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large backbone with a configurable classification head.

    Args:
        num_classes: size of the final classification layer.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.num_classes = num_classes

        # stem: 3x3 stride-2 conv, 3 -> 16 channels
        self.top = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish()
        )

        # bottleneck stack (kernel, in, expand, out, activation, use_se, stride)
        self.bneck = nn.Sequential(
            Bneck(3, 16, 16, 16, nn.ReLU(True), False, 1),
            Bneck(3, 16, 64, 24, nn.ReLU(True), False, 2),
            Bneck(3, 24, 72, 24, nn.ReLU(True), False, 1),

            Bneck(5, 24, 72, 40, nn.ReLU(True), True, 2),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),
            Bneck(5, 40, 120, 40, nn.ReLU(True), True, 1),

            Bneck(3, 40, 240, 80, hswish(), False, 2),
            Bneck(3, 80, 200, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 184, 80, hswish(), False, 1),
            Bneck(3, 80, 480, 112, hswish(), True, 1),
            Bneck(3, 112, 672, 112, hswish(), True, 1),

            # NOTE(review): the official Large config puts stride 2 on the
            # first 160-channel block (112 -> 160); here it sits on the second
            # block instead. Kept as-is for compatibility with the shipped
            # checkpoint — confirm against the paper before "fixing".
            Bneck(5, 112, 672, 160, hswish(), True, 1),
            Bneck(5, 160, 672, 160, hswish(), True, 2),
            Bneck(5, 160, 960, 160, hswish(), True, 1),
        )

        # final 1x1 conv: 160 -> 960 channels
        self.bottom = nn.Sequential(
            nn.Conv2d(160, 960, kernel_size=1, bias=False),
            nn.BatchNorm2d(960),
            hswish()
        )

        # classifier head: 960 -> 1280 -> num_classes
        self.last = nn.Sequential(
            nn.Linear(960, 1280),
            nn.BatchNorm1d(1280),
            hswish()
        )
        self.linear = nn.Linear(1280, num_classes)

        # BUGFIX: init_params() was previously called at the very top of
        # __init__, before any submodule existed, so self.modules() found
        # nothing and the custom initialization was silently a no-op. It must
        # run after all layers have been constructed.
        self.init_params()

    def init_params(self):
        """Apply the paper-style initialization to every submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.top(x)
        out = self.bneck(out)
        out = self.bottom(out)
        # Global average pool over the (assumed square) spatial map,
        # e.g. 7x7 for a 224x224 input.
        out = F.avg_pool2d(out, out.size(2))
        out = out.view(out.size(0), -1)
        out = self.last(out)
        out = self.linear(out)
        return out

    @staticmethod
    def from_pretrained_for_binary():
        """Download torchvision's ImageNet weights and swap in a 2-class head.

        NOTE(review): torchvision's MobileNetV3 uses different parameter names
        (features.*, classifier.*) than this hand-rolled architecture, so
        load_state_dict(..., strict=False) will match few or no tensors —
        verify how much actually transfers before relying on "pretrained"
        behavior.
        """
        print("Downloading official torchvision MobileNetV3-Large pretrained weights...")
        from torchvision.models import mobilenet_v3_large

        official_model = mobilenet_v3_large(pretrained=True)
        model = MobileNetV3_Large(num_classes=1000)
        model.load_state_dict(official_model.state_dict(), strict=False)

        # Replace the classification head for binary output.
        in_features = model.linear.in_features
        model.linear = nn.Linear(in_features, 2)
        print(f"Replaced classifier head: {in_features} -> 2")
        return model


if __name__ == "__main__":
    # Quick sanity check of the binary model.
    model = MobileNetV3_Large.from_pretrained_for_binary()
    x = torch.randn(4, 3, 224, 224)
    y = model(x)
    print(y.shape)  # [4, 2]
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mobilenetv3 import MobileNetV3_Large

# ImageNet statistics matching the pretrained backbone.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]


def _build_dataloaders(args):
    """Build train/val DataLoaders from ``{args.data}/train`` and ``{args.data}/val``.

    The training split gets a random horizontal flip; both splits are resized
    to 224x224 and normalized with the ImageNet mean/std.
    """
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_MEAN, std=_STD),
    ])
    transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_MEAN, std=_STD),
    ])

    # ImageFolder derives each image's label from its parent directory name.
    train_dataset = datasets.ImageFolder(root=f"{args.data}/train", transform=transform_train)
    val_dataset = datasets.ImageFolder(root=f"{args.data}/val", transform=transform_val)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    return train_loader, val_loader


def _evaluate(model, loader, device):
    """Return the accuracy (percent) of ``model`` over ``loader`` without gradients."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total


def train(args):
    """Train the binary MobileNetV3-Large classifier and save its weights.

    Args:
        args: argparse namespace with ``data``, ``epochs``, ``batch_size``, ``lr``.

    Side effects:
        Writes the final state_dict to ``mobilenetv3_binary.pth`` in the CWD
        and prints per-epoch train/validation accuracy.
    """
    train_loader, val_loader = _build_dataloaders(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MobileNetV3_Large(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        model.train()  # re-enable train mode after _evaluate() switched to eval
        running_loss = 0.0
        correct, total = 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100. * correct / total
        print(f"[Epoch {epoch+1}/{args.epochs}] Loss: {running_loss/len(train_loader):.4f} | Train Acc: {train_acc:.2f}%")

        # Per-epoch validation pass.
        val_acc = _evaluate(model, val_loader, device)
        print(f"Validation Acc: {val_acc:.2f}%")

    # NOTE(review): only the final epoch's weights are saved (no best-checkpoint
    # tracking) — this matches the original behavior.
    torch.save(model.state_dict(), "mobilenetv3_binary.pth")
    print("训练完成,模型已保存到 mobilenetv3_binary.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="./data", help="数据集路径")
    parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
    parser.add_argument("--batch_size", type=int, default=32, help="批大小")
    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
    args = parser.parse_args()
    train(args)