73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
# ---------------------------
|
|
# 配置路径(请按需修改)
|
|
# ---------------------------
|
|
MODEL_PATH = "charge.pt"
|
|
IMAGE_PATH = "test/class2.png" # 替换为你要测试的单张图片路径
|
|
|
|
# 类别映射:必须与训练时 data.yaml 的 names 顺序一致
|
|
CLASS_NAMES = {
|
|
0: "成功连接",
|
|
1: "未连接",
|
|
2: "有遮挡"
|
|
}
|
|
|
|
|
|
# ---------------------------
|
|
# 单张图片推理函数
|
|
# ---------------------------
|
|
def classify_single_image(model_path, image_path, class_names):
|
|
# 加载模型
|
|
model = YOLO(model_path)
|
|
print(f"模型加载成功: {model_path}")
|
|
|
|
# 读取图像
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
print(f"无法读取图像: {image_path}")
|
|
return
|
|
|
|
print(f"📷 正在推理: {image_path}")
|
|
|
|
# 推理
|
|
results = model(img)
|
|
probs = results[0].probs.data.cpu().numpy()
|
|
pred_class_id = int(probs.argmax())
|
|
pred_label = class_names[pred_class_id]
|
|
confidence = float(probs[pred_class_id])
|
|
|
|
# 输出结果
|
|
print("\n" + "=" * 40)
|
|
print(f"🔍 预测结果:")
|
|
print(f" 类别: {pred_label}")
|
|
print(f" 置信度: {confidence:.4f}")
|
|
print(f" 类别ID: {pred_class_id}")
|
|
print("=" * 40)
|
|
|
|
# (可选)在图像上显示结果并保存/显示
|
|
# 这里我们只打印,不保存。如需可视化,取消下面注释:
|
|
"""
|
|
label_text = f"{pred_label} ({confidence:.2f})"
|
|
cv2.putText(img, label_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
|
|
cv2.imshow("Result", img)
|
|
cv2.waitKey(0)
|
|
cv2.destroyAllWindows()
|
|
|
|
# 或保存带标签的图
|
|
# output_img_path = image_path.replace(".jpg", "_result.jpg")
|
|
# cv2.imwrite(output_img_path, img)
|
|
# print(f"带标签图像已保存: {output_img_path}")
|
|
"""
|
|
|
|
|
|
# ---------------------------
|
|
# 运行入口
|
|
# ---------------------------
|
|
if __name__ == "__main__":
|
|
classify_single_image(
|
|
model_path=MODEL_PATH,
|
|
image_path=IMAGE_PATH,
|
|
class_names=CLASS_NAMES
|
|
) |