Files
zjsh_yolov11/zjsh_code/charge_3cls/val/main_3cls_charge.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

73 lines
2.0 KiB
Python

import cv2
from ultralytics import YOLO
# ---------------------------
# 配置路径(请按需修改)
# ---------------------------
MODEL_PATH = "charge.pt"
IMAGE_PATH = "test/class2.png" # 替换为你要测试的单张图片路径
# 类别映射:必须与训练时 data.yaml 的 names 顺序一致
CLASS_NAMES = {
0: "成功连接",
1: "未连接",
2: "有遮挡"
}
# ---------------------------
# 单张图片推理函数
# ---------------------------
def classify_single_image(model_path, image_path, class_names):
# 加载模型
model = YOLO(model_path)
print(f"模型加载成功: {model_path}")
# 读取图像
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图像: {image_path}")
return
print(f"📷 正在推理: {image_path}")
# 推理
results = model(img)
probs = results[0].probs.data.cpu().numpy()
pred_class_id = int(probs.argmax())
pred_label = class_names[pred_class_id]
confidence = float(probs[pred_class_id])
# 输出结果
print("\n" + "=" * 40)
print(f"🔍 预测结果:")
print(f" 类别: {pred_label}")
print(f" 置信度: {confidence:.4f}")
print(f" 类别ID: {pred_class_id}")
print("=" * 40)
# (可选)在图像上显示结果并保存/显示
# 这里我们只打印,不保存。如需可视化,取消下面注释:
"""
label_text = f"{pred_label} ({confidence:.2f})"
cv2.putText(img, label_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
cv2.imshow("Result", img)
cv2.waitKey(0)
cv2.destroyAllWindows()
# 或保存带标签的图
# output_img_path = image_path.replace(".jpg", "_result.jpg")
# cv2.imwrite(output_img_path, img)
# print(f"带标签图像已保存: {output_img_path}")
"""
# ---------------------------
# 运行入口
# ---------------------------
if __name__ == "__main__":
classify_single_image(
model_path=MODEL_PATH,
image_path=IMAGE_PATH,
class_names=CLASS_NAMES
)