# File-viewer metadata captured with this listing: Python, 44 lines, 1.4 KiB
import torch
from PIL import Image
from torchvision import transforms

from class2.test.mobilenetv3 import MobileNetV3_Large  # your model file name

# -------------------------------
# 1. Select the inference device
# -------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# 2. Build the model and load its weights
# -------------------------------
model = MobileNetV3_Large(num_classes=3)
# map_location remaps the stored tensors onto the module-level `device`, so a
# checkpoint saved on a GPU machine can still be loaded on a CPU-only one.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch has it).
# NOTE(review): the file name says "binary" but num_classes=3 - confirm that
# the checkpoint really matches a 3-class head.
model.load_state_dict(torch.load("mobilenetv3_binary.pth", map_location=device))
model.to(device)
model.eval()  # evaluation (inference) mode

# -------------------------------
# 3. Image preprocessing
# -------------------------------
# Pipeline applied to every input image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values are presumably the ImageNet channel
# statistics - confirm they match what was used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict(image_path):
    """Classify a single image with the module-level ``model``.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.

    Returns:
        A ``(pred_class, probs)`` tuple. ``pred_class`` is the integer index
        of the highest-probability class; ``probs`` is a NumPy array holding
        the softmax probabilities (assumes the model returns logits shaped
        ``(1, num_classes)`` - TODO confirm).
    """
    # The context manager releases the file handle as soon as the pixels are
    # decoded; convert() returns an independent in-memory RGB image.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    x = transform(img).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():  # gradients are not needed for inference
        outputs = model(x)
        probs = torch.softmax(outputs, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
    return pred_class, probs.cpu().numpy()


# -------------------------------
# 4. Run a test inference
# -------------------------------
# Runs at module level (also on import): classifies one sample image and
# prints the predicted class index together with the probabilities.
image_path = "test/2.png"
cls, prob = predict(image_path)
print(f"Predicted class: {cls}, Probabilities: {prob}")