Files
Feeding_control_system/vision/charge_3cls/charge_cls_rknn.py
2026-04-07 09:51:38 +08:00

132 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from collections import deque
class StableClassJudge:
    """Debounce per-frame class predictions into a stable decision.

    A class id is reported only after it has been seen in ``stable_frames``
    consecutive updates (3 by default). Seeing ``ignore_class`` (class 2 by
    default) throws away the frames collected so far and counting restarts.
    Once a decision is reported the history is cleared, so the next decision
    needs a fresh run of matching frames.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Sliding window: only the most recent ``stable_frames`` ids are kept.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget every frame collected so far."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one frame's class id.

        Returns the stable class id when the window is full and every entry
        agrees, otherwise ``None``.
        """
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)
        window_full = len(self.buffer) >= self.stable_frames
        if not (window_full and len(set(self.buffer)) == 1):
            return None

        decided = self.buffer[0]
        self.reset()
        return decided
# ---------------------------
# Three-class label mapping
# ---------------------------
# Class index -> display label, in model output order:
#   0 "插好"   -> inserted properly
#   1 "未插好" -> not inserted properly
#   2 "有遮挡" -> occluded
CLASS_NAMES = dict(enumerate(("插好", "未插好", "有遮挡")))
# ---------------------------
# Process-wide RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None


def init_rknn_model(model_path):
    """Load the RKNN model on first use and return the cached handle.

    The first successful call loads ``model_path`` and initialises the
    runtime on NPU core 0; every later call returns that same handle.

    NOTE: only one model is cached. After the first successful load,
    ``model_path`` is ignored, so a different path still gets the
    originally loaded model.

    Args:
        model_path: Path to the ``.rknn`` model file.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero code. The half-initialised handle is
            released before the error propagates.
    """
    global _global_rknn
    # Fast path first: a cached handle needs no import or NPU work.
    if _global_rknn is not None:
        return _global_rknn

    # Imported lazily so the module can be imported without rknnlite present.
    from rknnlite.api import RKNNLite

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Do not leak the runtime/NPU resources of a failed initialisation.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功:{model_path}")
    return rknn
# ---------------------------
# Preprocessing (feed uint8; the RKNN runtime converts to float32 internally)
# ---------------------------
def resize_stretch(image, size=640):
    """Stretch *image* to ``size`` x ``size`` pixels.

    The aspect ratio is not preserved and no padding is added, so
    non-square inputs are distorted.
    """
    square = (size, size)  # cv2.resize expects dsize as (width, height)
    return cv2.resize(image, square)
def preprocess_image_for_rknn(img, size=640):
    """Build the NHWC uint8 input tensor for the RKNN classifier.

    Steps: stretch-resize to ``size`` x ``size``, convert to RGB, and add a
    leading batch axis. The tensor stays uint8 in [0, 255] even though the
    model is float32; per the original author, the RKNN runtime performs
    that conversion internally.

    Args:
        img: Image as returned by ``cv2.imread`` (3-channel BGR). A
            single-channel grayscale image (``HxW`` or ``HxWx1``) is also
            accepted and replicated to three RGB channels. The expected dtype
            is uint8; other dtypes are cast with ``astype`` without rescaling.
        size: Side length of the square model input.

    Returns:
        C-contiguous ``np.uint8`` array of shape ``(1, size, size, 3)``.
    """
    resized = resize_stretch(img, size=size)
    # BGR2RGB rejects single-channel input, so grayscale needs its own code.
    if resized.ndim == 2 or resized.shape[2] == 1:
        rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    else:
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    input_tensor = np.expand_dims(rgb, axis=0).astype(np.uint8)  # NHWC, uint8
    return np.ascontiguousarray(input_tensor)
# ---------------------------
# Single RKNN inference, three-class output (float32 model)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference on an already-preprocessed input tensor.

    Args:
        input_tensor: Batch tensor as built by ``preprocess_image_for_rknn``.
        model_path: Model to load on first use (forwarded to
            ``init_rknn_model``, which caches a single model).

    Returns:
        ``(class_id, probs)``. ``probs`` is the model's first output,
        flattened and cast to float32 (three values for this 3-class model).
        ``class_id`` is the index of the largest value. No softmax is applied
        here, and the original author describes the output as raw logits.
        NOTE(review): confirm whether the model already ends in softmax,
        because callers report ``probs[class_id]`` as a "score".

    Raises:
        RuntimeError: If the runtime returns no outputs or an empty output.
    """
    rknn = init_rknn_model(model_path)
    outs = rknn.inference([input_tensor])
    # Guard against a failed inference (None / empty list) instead of letting
    # outs[0] fail with an unhelpful TypeError or IndexError.
    if not outs:
        raise RuntimeError(f"RKNN inference returned no output: {outs!r}")

    probs = np.asarray(outs[0]).flatten().astype(np.float32)  # expected shape: (3,)
    if probs.size == 0:
        raise RuntimeError("RKNN inference returned an empty output tensor")

    class_id = int(np.argmax(probs))
    return class_id, probs
# ---------------------------
# Single-image inference
# ---------------------------
def classify_single_image(frame, model_path, size=640):
    """Classify one image frame into one of the ``CLASS_NAMES`` classes.

    Args:
        frame: Image array (BGR, as from ``cv2.imread`` or a camera).
        model_path: RKNN model path (loaded once and then cached).
        size: Square side length the frame is resized to before inference.

    Returns:
        dict with keys:
            ``"class_id"``: predicted class index (int);
            ``"class"``: label from ``CLASS_NAMES``, or a fallback text for
            unknown ids;
            ``"score"``: the predicted class's output value, rounded to
            4 decimals;
            ``"raw"``: all output values as a Python list.

    Raises:
        FileNotFoundError: If *frame* is ``None`` or contains no pixels.
            This exception type is kept for backward compatibility with
            existing callers.
    """
    # An empty array would otherwise crash deep inside cv2.resize.
    if frame is None or np.size(frame) == 0:
        raise FileNotFoundError("输入帧为空")

    input_tensor = preprocess_image_for_rknn(frame, size=size)
    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist(),
    }
# ---------------------------
# Example usage:
#   python charge_cls_rknn.py [model_path] [image_path]
# ---------------------------
if __name__ == "__main__":
    import sys

    # Optional CLI overrides; the defaults are the original demo paths.
    model_path = sys.argv[1] if len(sys.argv) > 1 else "charge0324.rknn"
    image_path = sys.argv[2] if len(sys.argv) > 2 else "class2.jpg"
    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"无法读取图片:{image_path}")
    result = classify_single_image(frame, model_path)
    print("[RESULT]", result)