# Source listing: zjsh_classification/class3/classify.py
# Recorded: 2025-08-14 18:27:52 +08:00 (44 lines, 1.4 KiB, Python)

import torch
from torchvision import transforms
from PIL import Image
from class2.test.mobilenetv3 import MobileNetV3_Large # 你的模型文件名
# -------------------------------
# 1. Inference device
# -------------------------------
# Prefer CUDA whenever PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -------------------------------
# 2. Build the network and load its trained weights
# -------------------------------
# NOTE(review): the checkpoint file is named "binary" but the classifier head
# is built with num_classes=3. load_state_dict is strict by default, so a
# head-size mismatch would raise here -- confirm the checkpoint matches.
# torch.load uses pickle under the hood; only point it at trusted checkpoints.
model = MobileNetV3_Large(num_classes=3)
model.load_state_dict(
    torch.load("mobilenetv3_binary.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself, so this keeps
# the same object while moving it to `device` and switching to inference mode.
model = model.to(device).eval()
# -------------------------------
# 3. Image preprocessing
# -------------------------------
# Resize to exactly 224x224 (aspect ratio is not preserved because a (h, w)
# tuple is given), convert the PIL image to a float tensor, then normalize each
# channel with the widely used ImageNet mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image_path):
    """Classify a single image with the module-level ``model``.

    The image is opened, converted to 3-channel RGB, run through ``transform``,
    given a leading batch dimension of size 1, moved to ``device`` and passed
    through ``model`` with autograd disabled.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).

    Returns:
        A tuple ``(pred_class, probs)``: ``pred_class`` is the integer index of
        the highest softmax probability, and ``probs`` is a NumPy array of the
        softmax probabilities (with a leading batch axis of size 1).
    """
    # Open via a context manager so the underlying file handle is released
    # deterministically; convert() returns an independent, fully loaded image.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # unsqueeze(0) adds the batch dimension the network expects.
    x = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(x)
        # Softmax over dim=1 -- assumes the model returns (batch, num_classes)
        # logits; TODO confirm against the model definition.
        probs = torch.softmax(outputs, dim=1)
    pred_class = torch.argmax(probs, dim=1).item()
    return pred_class, probs.cpu().numpy()
# -------------------------------
# 4. Test inference
# -------------------------------
# Guarded so that importing this module (e.g. to reuse `predict`) does not run
# a demo inference on a hard-coded path.
if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default sample image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "test/2.png"
    cls, prob = predict(image_path)
    print(f"Predicted class: {cls}, Probabilities: {prob}")