import os
from collections import deque

import cv2
import numpy as np
from rknnlite.api import RKNNLite


class StableClassJudge:
    """Debounce per-frame classification results.

    A class id is reported only once it has been seen in ``stable_frames``
    consecutive frames. A frame classified as ``ignore_class`` clears the
    history, so counting starts over from the next frame.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Sliding window holding the most recent non-ignored class ids.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Discard all buffered frame results."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one per-frame class id into the judge.

        Returns:
            ``None`` while the result is not yet stable, otherwise the class
            id that filled the whole window.
        """
        # The ignored class interrupts any streak: clear and start over.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Window not full yet.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every frame in the window agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            # Restart counting after reporting so the same streak does not
            # trigger again on the very next frame.
            self.reset()
            return stable_class

        return None


# ---------------------------
# Class id -> display name (three classes)
# ---------------------------
CLASS_NAMES = {
    0: "插好",
    1: "未插好",
    2: "有遮挡"
}

# ---------------------------
# Process-wide RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None
# Path the cached instance was loaded from; used only to warn on mismatch.
_global_rknn_path = None


def init_rknn_model(model_path):
    """Load the RKNN model once and return the shared ``RKNNLite`` instance.

    The first successful call loads ``model_path`` and initialises the
    runtime on ``RKNNLite.NPU_CORE_0``. Later calls return the cached
    instance; if they ask for a different path, a warning is printed and the
    already-loaded model is still returned.

    Raises:
        RuntimeError: if loading the model or initialising the runtime
            returns a non-zero status. The partially created handle is
            released before raising.
    """
    global _global_rknn, _global_rknn_path

    if _global_rknn is not None:
        if _global_rknn_path is not None and model_path != _global_rknn_path:
            print(
                f"[WARN] RKNN model already loaded from {_global_rknn_path}; "
                f"ignoring request for {model_path}"
            )
        return _global_rknn

    rknn = RKNNLite(verbose=False)

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        # Do not leak the handle when the model file cannot be loaded.
        rknn.release()
        raise RuntimeError(f"Load RKNN failed: {ret}")

    ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        # Same here: free the loaded model if the runtime refuses to start.
        rknn.release()
        raise RuntimeError(f"Init runtime failed: {ret}")

    _global_rknn = rknn
    _global_rknn_path = model_path
    print(f"[INFO] RKNN 模型加载成功:{model_path}")
    return rknn


# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Resize ``image`` keeping its aspect ratio and pad it to a square.

    The resized image is centred on a ``new_size`` x ``new_size`` uint8
    canvas filled with ``color``. The canvas is allocated with 3 channels,
    so a 3-channel input is expected.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # Clamp to [1, new_size]: truncation can otherwise produce 0 for very
    # elongated inputs, which cv2.resize rejects.
    nh = max(1, min(new_size, int(h * scale)))
    nw = max(1, min(new_size, int(w * scale)))

    resized = cv2.resize(image, (nw, nh))
    new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)

    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img


def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without keeping aspect ratio."""
    return cv2.resize(image, (size, size))


def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Turn an image into a contiguous float32 batch-of-one input tensor.

    Args:
        img: input image; treated as BGR when ``to_rgb`` is true.
        size: side length of the square output.
        resize_mode: ``"letterbox"`` pads to keep the aspect ratio; any other
            value stretches the image.
        to_rgb: convert BGR to RGB after resizing.
        normalize: divide pixel values by 255.
        layout: ``"NHWC"`` keeps channels last; any other value produces
            channels-first (NCHW).

    Returns:
        A C-contiguous ``float32`` array with a leading batch axis of 1.
    """
    if resize_mode == "letterbox":
        img_box = letterbox(img, new_size=size)
    else:
        img_box = resize_stretch(img, size=size)

    if to_rgb:
        img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)

    img_f = img_box.astype(np.float32)

    if normalize:
        img_f /= 255.0

    if layout == "NHWC":
        out = np.expand_dims(img_f, axis=0)
    else:
        out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)

    return np.ascontiguousarray(out)


# ---------------------------
# Single RKNN inference (three classes)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference on a preprocessed tensor and apply softmax.

    Returns:
        ``(class_id, probs)`` where ``probs`` is the softmax over the
        flattened first model output and ``class_id`` is its argmax.

    Raises:
        RuntimeError: if the model cannot be initialised or the runtime
            returns no output.
    """
    rknn = init_rknn_model(model_path)

    outs = rknn.inference([input_tensor])
    # Fail with a clear message instead of an opaque TypeError/IndexError
    # when the runtime reports failure by returning nothing.
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")

    # Flatten to one logit per class (three expected for this model).
    logits = outs[0].reshape(-1).astype(np.float32)

    # Numerically stable softmax: subtract the max before exponentiating.
    exp = np.exp(logits - np.max(logits))
    probs = exp / np.sum(exp)

    class_id = int(np.argmax(probs))
    return class_id, probs


# ---------------------------
# Single-image inference (three classes) - ROI logic removed
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Classify one full frame and return a result dictionary.

    The whole frame is used; no region of interest is cropped.

    Returns:
        A dict with ``class_id``, ``class`` (display name from
        ``CLASS_NAMES``), ``score`` (probability of the predicted class,
        rounded to 4 decimals) and ``raw`` (all probabilities as a list).

    Raises:
        FileNotFoundError: if ``frame`` is ``None``.
    """
    if frame is None:
        raise FileNotFoundError("❌ 输入帧为空")

    # Use the whole image directly, no cropping.
    input_tensor = preprocess_image_for_rknn(
        frame,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout
    )

    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")

    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist()
    }


# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    model_path = "charge_cls.rknn"
    # roi_file has been removed
    image_path = "class2.png"

    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"❌ 无法读取图片:{image_path}")

    # Run the classifier
    result = classify_single_image(frame, model_path)
    print("[RESULT]", result)