最新推送
This commit is contained in:
73
zjsh_code/muju_cls/val/main_pc.py
Normal file
73
zjsh_code/muju_cls/val/main_pc.py
Normal file
@ -0,0 +1,73 @@
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# ---------------------------
|
||||
# 配置路径(请按需修改)
|
||||
# ---------------------------
|
||||
MODEL_PATH = "muju.pt"
|
||||
IMAGE_PATH = "./test_img/class0.png" # 替换为你要测试的单张图片路径
|
||||
|
||||
# 类别映射:必须与训练时 data.yaml 的 names 顺序一致
|
||||
CLASS_NAMES = {
|
||||
0: "模具车非f块",
|
||||
1: "模具车f块",
|
||||
2: "有遮挡"
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 单张图片推理函数
|
||||
# ---------------------------
|
||||
def classify_single_image(model_path, image_path, class_names):
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
print(f"模型加载成功: {model_path}")
|
||||
|
||||
# 读取图像
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"无法读取图像: {image_path}")
|
||||
return
|
||||
|
||||
print(f"📷 正在推理: {image_path}")
|
||||
|
||||
# 推理
|
||||
results = model(img)
|
||||
probs = results[0].probs.data.cpu().numpy()
|
||||
pred_class_id = int(probs.argmax())
|
||||
pred_label = class_names[pred_class_id]
|
||||
confidence = float(probs[pred_class_id])
|
||||
|
||||
# 输出结果
|
||||
print("\n" + "=" * 40)
|
||||
print(f"🔍 预测结果:")
|
||||
print(f" 类别: {pred_label}")
|
||||
print(f" 置信度: {confidence:.4f}")
|
||||
print(f" 类别ID: {pred_class_id}")
|
||||
print("=" * 40)
|
||||
|
||||
# (可选)在图像上显示结果并保存/显示
|
||||
# 这里我们只打印,不保存。如需可视化,取消下面注释:
|
||||
"""
|
||||
label_text = f"{pred_label} ({confidence:.2f})"
|
||||
cv2.putText(img, label_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
|
||||
cv2.imshow("Result", img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 或保存带标签的图
|
||||
# output_img_path = image_path.replace(".jpg", "_result.jpg")
|
||||
# cv2.imwrite(output_img_path, img)
|
||||
# print(f"带标签图像已保存: {output_img_path}")
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 运行入口
|
||||
# ---------------------------
|
||||
if __name__ == "__main__":
|
||||
classify_single_image(
|
||||
model_path=MODEL_PATH,
|
||||
image_path=IMAGE_PATH,
|
||||
class_names=CLASS_NAMES
|
||||
)
|
||||
Reference in New Issue
Block a user