Files
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

89 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
from ultralytics import YOLO
# ---------------------------
# 类别映射(必须与训练时 data.yaml 一致)
# ---------------------------
CLASS_NAMES = {
0: "未堆料",
1: "小堆料",
2: "大堆料",
3: "未浇筑满",
4: "浇筑满"
}
# ---------------------------
# 加权判断函数(仅用于 class1/class2
# ---------------------------
def weighted_small_large(pred_probs, threshold=0.4, w1=0.3, w2=0.7):
p1 = float(pred_probs[1])
p2 = float(pred_probs[2])
total = p1 + p2
if total > 0:
score = (w1 * p1 + w2 * p2) / total
else:
score = 0.0
final_class = "大堆料" if score >= threshold else "小堆料"
return final_class, score, p1, p2
# ---------------------------
# 单张图片推理主函数
# ---------------------------
def classify_single_image(model_path, image_path, threshold=0.5):
# 加载模型
print("🚀 加载模型...")
model = YOLO(model_path)
# 读取图像
img = cv2.imread(image_path)
if img is None:
raise FileNotFoundError(f"❌ 无法读取图像: {image_path}")
print(f"📷 推理图像: {image_path}")
# 整图分类(不裁剪)
results = model(img)
pred_probs = results[0].probs.data.cpu().numpy().flatten()
class_id = int(pred_probs.argmax())
confidence = float(pred_probs[class_id])
class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
# 对 小堆料/大堆料 使用加权逻辑
if class_id in [1, 2]:
final_class, score, p1, p2 = weighted_small_large(pred_probs, threshold=threshold)
print("\n🔍 检测到堆料区域,使用加权判断:")
print(f" 小堆料概率: {p1:.4f}")
print(f" 大堆料概率: {p2:.4f}")
print(f" 加权得分: {score:.4f} (阈值={threshold})")
else:
final_class = class_name
score = confidence
# 输出最终结果
print("\n" + "="*40)
print(f"最终分类结果: {final_class}")
print(f"置信度/得分: {score:.4f}")
print("="*40)
return final_class, score
# ---------------------------
# 运行入口(请修改路径)
# ---------------------------
if __name__ == "__main__":
MODEL_PATH = "60best.pt"
IMAGE_PATH = "class4.png" # 👈 改成你的单张图片路径
# 可选:调整加权阈值(默认 0.4
THRESHOLD = 0.4
try:
result_class, result_score = classify_single_image(
model_path=MODEL_PATH,
image_path=IMAGE_PATH,
threshold=THRESHOLD
)
except Exception as e:
print(f"程序出错: {e}")