"""Run single-image inference with a trained MobileNetV3-Large classifier.

Loads weights from ``mobilenetv3_binary.pth``, preprocesses an image and,
when run as a script, prints the predicted class index together with the
softmax probabilities.
"""

import torch
from torchvision import transforms
from PIL import Image

from class2.test.mobilenetv3 import MobileNetV3_Large  # project model definition

# -------------------------------
# 1. Inference device
# -------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# 2. Build the model and load its weights
# -------------------------------
# NOTE(review): the checkpoint name says "binary" but the model is built with
# 3 output classes; load_state_dict would raise if the checkpoint's classifier
# head has a different size — confirm which is intended.
model = MobileNetV3_Large(num_classes=3)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
model.to(device)
model.eval()  # switch layers such as dropout / batch-norm to inference behavior

# -------------------------------
# 3. Image preprocessing
# -------------------------------
# Resize to a fixed 224x224 (aspect ratio is not preserved), convert to a
# float tensor in [0, 1], then normalize with the ImageNet channel statistics.
# Presumably this matches the training-time preprocessing — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict(image_path):
    """Classify a single image with the module-level ``model``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``(pred_class, probs)`` tuple. ``pred_class`` is the argmax class
        index as a Python ``int``. ``probs`` is a NumPy array of softmax
        probabilities for the single-image batch (assumes the model returns
        logits shaped ``(batch, num_classes)``, since softmax runs over dim 1).
    """
    # Open inside a context manager so the underlying file handle is closed;
    # convert() loads the pixels and returns a new image that outlives it.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    x = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        outputs = model(x)
        probs = torch.softmax(outputs, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
    return pred_class, probs.cpu().numpy()


# -------------------------------
# 4. Test inference
# -------------------------------
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse predict) does not
    # run the demo prediction.
    image_path = "test/2.png"
    cls, prob = predict(image_path)
    print(f"Predicted class: {cls}, Probabilities: {prob}")