import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
from collections import deque


class StableClassJudge:
    """Debounce per-frame class ids into a single stable decision.

    A class id is reported once it has been seen on ``stable_frames``
    consecutive frames. Seeing ``ignore_class`` (the "occluded" class by
    default) discards everything collected so far, so counting restarts.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Bounded window: once full, appending drops the oldest frame.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget all frames collected so far."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one per-frame classification result.

        Returns:
            ``None`` while the result is not yet stable, otherwise the class
            id shared by the last ``stable_frames`` frames.
        """
        # The ignored class invalidates the current streak entirely.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Not enough frames collected yet.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every frame in the window agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            # Restart counting after reporting so the same streak does not
            # fire again on the very next frame.
            self.reset()
            return stable_class

        return None


# ---------------------------
# Three-class label mapping
# ---------------------------
CLASS_NAMES = {
    0: "模具车1",
    1: "模具车2",
    2: "有遮挡"
}

# ---------------------------
# Global RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None
# Path the cached instance was loaded from; used only to warn on mismatch.
_global_rknn_path = None


def init_rknn_model(model_path):
    """Load the RKNN model once and return the shared ``RKNNLite`` instance.

    The first successful call loads ``model_path`` and initialises the runtime
    on NPU core 0. Later calls return the cached instance; if they ask for a
    different path, a warning is printed and the cached model is still
    returned.

    Raises:
        RuntimeError: if loading the model or initialising the runtime fails.
    """
    global _global_rknn, _global_rknn_path
    if _global_rknn is not None:
        if _global_rknn_path is not None and model_path != _global_rknn_path:
            print(
                f"[WARN] RKNN model already loaded from {_global_rknn_path}; "
                f"ignoring request for {model_path}"
            )
        return _global_rknn

    rknn = RKNNLite(verbose=False)

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        # Free the handle instead of leaking it on a failed load.
        rknn.release()
        raise RuntimeError(f"Load RKNN failed: {ret}")

    ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        rknn.release()
        raise RuntimeError(f"Init runtime failed: {ret}")

    _global_rknn = rknn
    _global_rknn_path = model_path
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn


# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Resize ``image`` to fit a ``new_size`` square, keeping aspect ratio.

    The resized image is centred on a 3-channel ``uint8`` canvas filled with
    ``color``.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # Clamp to 1 px: int() can truncate the short side of a very elongated
    # image to 0, which cv2.resize rejects.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))

    new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img


def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size``, ignoring aspect ratio."""
    return cv2.resize(image, (size, size))


def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Turn an image into a batched float32 tensor for RKNN inference.

    Args:
        img: input image; converted with ``COLOR_BGR2RGB`` when ``to_rgb``.
        size: side length of the square network input.
        resize_mode: ``"letterbox"`` pads to keep aspect ratio; any other
            value stretches the image.
        to_rgb: swap the channel order from BGR to RGB.
        normalize: divide pixel values by 255.
        layout: ``"NHWC"`` keeps channels last; any other value transposes to
            channels first.

    Returns:
        A C-contiguous float32 array with a leading batch dimension of 1.
    """
    if resize_mode == "letterbox":
        img_box = letterbox(img, new_size=size)
    else:
        img_box = resize_stretch(img, size=size)

    if to_rgb:
        img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)

    img_f = img_box.astype(np.float32)
    if normalize:
        img_f /= 255.0

    if layout == "NHWC":
        out = np.expand_dims(img_f, axis=0)
    else:
        out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)

    return np.ascontiguousarray(out)


# ---------------------------
# Single RKNN inference (three-class)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run the classifier on a preprocessed tensor.

    Returns:
        ``(class_id, probs)`` where ``probs`` is the softmax over the
        flattened first model output and ``class_id`` is its argmax.

    Raises:
        RuntimeError: if the model cannot be loaded or inference yields no
            outputs.
    """
    rknn = init_rknn_model(model_path)

    outs = rknn.inference([input_tensor])
    # Fail with a clear message rather than a TypeError/IndexError on outs[0].
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no outputs")

    logits = outs[0].reshape(-1).astype(np.float32)

    # Softmax; subtracting the max keeps exp() from overflowing.
    exp = np.exp(logits - np.max(logits))
    probs = exp / np.sum(exp)

    class_id = int(np.argmax(probs))
    return class_id, probs


# ---------------------------
# ROI
# ---------------------------
def load_single_roi(txt_path):
    """Read the first non-blank ``x,y,w,h`` line from an ROI text file.

    Returns:
        ``(x, y, w, h)`` as a tuple of ints.

    Raises:
        RuntimeError: if the file does not exist or has no non-blank line.
        ValueError: if the first non-blank line is not four comma-separated
            integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")

    # utf-8-sig also strips a BOM, which would otherwise break int().
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError as err:
                raise ValueError(
                    f"Invalid ROI line (expected 'x,y,w,h'): {s!r}"
                ) from err
            return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")


def crop_and_return_roi(img, roi):
    """Return the ``(x, y, w, h)`` region of ``img`` as a slice of it.

    Raises:
        RuntimeError: if the ROI has a non-positive size or does not lie
            entirely inside the image.
    """
    x, y, w, h = roi
    h_img, w_img = img.shape[:2]

    # A zero or negative size would give an empty crop that only fails later,
    # inside cv2.resize, with an unrelated-looking error.
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI has non-positive size: {roi}")

    if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")

    return img[y:y + h, x:x + w]


# ---------------------------
# Single-image inference (three-class)
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    roi_file,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Classify the ROI of one frame.

    The ROI is read from ``roi_file`` on every call, cropped from ``frame``,
    preprocessed, and passed to the RKNN model.

    Returns:
        A dict with ``class_id``, ``class`` (display name), ``score`` (the
        winning probability rounded to 4 decimals) and ``raw`` (all
        probabilities as a list).

    Raises:
        FileNotFoundError: if ``frame`` is ``None``.
        RuntimeError: for ROI, model loading, or inference failures.
    """
    if frame is None:
        raise FileNotFoundError("❌ 输入帧为空")

    roi = load_single_roi(roi_file)
    roi_img = crop_and_return_roi(frame, roi)

    input_tensor = preprocess_image_for_rknn(
        roi_img,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout
    )

    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")

    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist()
    }


# ---------------------------
# Example invocation
# ---------------------------
if __name__ == "__main__":
    model_path = "muju_cls.rknn"
    roi_file = "./roi_coordinates/muju_roi.txt"
    image_path = "./test_image/test.png"

    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"❌ 无法读取图片: {image_path}")

    result = classify_single_image(frame, model_path, roi_file)
    print("[RESULT]", result)

# ---------------------------
# Example stable-judgement loop (kept as an inert string, not executed)
'''
import cv2
from muju_cls_rknn import classify_single_image,StableClassJudge,CLASS_NAMES

def run_stable_classification_loop(
    model_path,
    roi_file,
    image_source,
    stable_frames=3
):
    """
    image_source:
        - cv2.VideoCapture
    """
    judge = StableClassJudge(
        stable_frames=stable_frames,
        ignore_class=2  # 有遮挡
    )

    cap = image_source
    if not hasattr(cap, "read"):
        raise TypeError("image_source 必须是 cv2.VideoCapture")

    while True:
        ret, frame = cap.read()
        if not ret:
            print("读取帧失败,退出")
            break

        result = classify_single_image(frame, model_path, roi_file)

        class_id = result["class_id"]
        class_name = result["class"]
        score = result["score"]

        print(f"[FRAME] {class_name} conf={score}")

        stable = judge.update(class_id)
        if stable is not None:
            print(f"\n稳定输出: {CLASS_NAMES[stable]} \n")

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
'''
# ---------------------------