import os import cv2 import numpy as np from collections import deque class StableClassJudge: """ 连续三帧稳定判决器: - class0 / class1 连续 3 帧 -> 输出 - class2 -> 清空计数,重新统计 """ def __init__(self, stable_frames=3, ignore_class=2): self.stable_frames = stable_frames self.ignore_class = ignore_class self.buffer = deque(maxlen=stable_frames) def reset(self): self.buffer.clear() def update(self, class_id): if class_id == self.ignore_class: self.reset() return None self.buffer.append(class_id) if len(self.buffer) < self.stable_frames: return None if len(set(self.buffer)) == 1: stable_class = self.buffer[0] self.reset() return stable_class return None # --------------------------- # 三分类映射 # --------------------------- CLASS_NAMES = { 0: "插好", 1: "未插好", 2: "有遮挡" } # --------------------------- # RKNN 全局实例(只加载一次) # --------------------------- _global_rknn = None def init_rknn_model(model_path): from rknnlite.api import RKNNLite global _global_rknn if _global_rknn is not None: return _global_rknn rknn = RKNNLite(verbose=False) ret = rknn.load_rknn(model_path) if ret != 0: raise RuntimeError(f"Load RKNN failed: {ret}") ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0) if ret != 0: raise RuntimeError(f"Init runtime failed: {ret}") _global_rknn = rknn print(f"[INFO] RKNN 模型加载成功:{model_path}") return rknn # --------------------------- # 预处理(输入 uint8,RKNN 内部转 float32) # --------------------------- def resize_stretch(image, size=640): return cv2.resize(image, (size, size)) def preprocess_image_for_rknn(img, size=640): # 输入必须是 uint8 [0,255],即使模型是 float32 img_resized = resize_stretch(img, size=size) img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB) input_tensor = np.expand_dims(img_rgb, axis=0).astype(np.uint8) # NHWC, uint8 return np.ascontiguousarray(input_tensor) # --------------------------- # 单次 RKNN 推理(三分类,float32 模型) # --------------------------- def rknn_classify_preprocessed(input_tensor, model_path): rknn = init_rknn_model(model_path) outs = rknn.inference([input_tensor]) # 直接得到 logits probs = outs[0].flatten().astype(np.float32) # shape: (3,) class_id = int(np.argmax(probs)) return class_id, probs # --------------------------- # 单张图片推理 # --------------------------- def classify_single_image(frame, model_path, size=640): if frame is None: raise FileNotFoundError("输入帧为空") input_tensor = preprocess_image_for_rknn(frame, size=size) class_id, probs = rknn_classify_preprocessed(input_tensor, model_path) class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})") return { "class_id": class_id, "class": class_name, "score": round(float(probs[class_id]), 4), "raw": probs.tolist() } # --------------------------- # 示例调用 # --------------------------- if __name__ == "__main__": model_path = "charge0324.rknn" image_path = "class2.jpg" frame = cv2.imread(image_path) if frame is None: raise FileNotFoundError(f"无法读取图片:{image_path}") result = classify_single_image(frame, model_path) print("[RESULT]", result)