import os
import time
from collections import deque
from datetime import datetime
from urllib.parse import urlsplit, urlunsplit

import cv2
import numpy as np
from rknnlite.api import RKNNLite


# =====================================================
# Stability judge
# =====================================================
class StableClassJudge:
    """Debounce per-frame class ids into stable decisions.

    A class id is emitted only after it is observed on ``stable_frames``
    consecutive updates:

    - class 0 / class 1 seen N times in a row -> returned once, then the
      window is cleared so the next decision needs N fresh frames;
    - ``ignore_class`` (class 2, "occluded" in CLASS_NAMES) clears the window
      and yields no decision.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        if stable_frames < 1:
            # deque(maxlen=0) never fills, so the judge would silently never fire.
            raise ValueError(f"stable_frames must be >= 1, got {stable_frames}")
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Sliding window of the most recent class ids (oldest dropped when full).
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget all buffered observations."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one frame's class id.

        Returns the class id when the last ``stable_frames`` ids all agree
        (and clears the window), otherwise None.
        """
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)
        if len(self.buffer) < self.stable_frames:
            return None

        if len(set(self.buffer)) == 1:
            stable = self.buffer[0]
            self.reset()
            return stable
        return None


# =====================================================
# Class definitions (index -> display name)
# =====================================================
CLASS_NAMES = {
    0: "模具车1",
    1: "模具车2",
    2: "有遮挡",
}


# =====================================================
# Process-wide RKNN instance
# =====================================================
_global_rknn = None
_global_rknn_path = None


def init_rknn_model(model_path):
    """Load ``model_path`` onto NPU core 0 once and return the shared RKNNLite.

    Later calls with the same path reuse the cached instance.

    Raises:
        RuntimeError: if loading / runtime initialisation fails, or if a
            different model path is requested after one is already loaded
            (the old code silently returned the first model).
    """
    global _global_rknn, _global_rknn_path

    requested = os.path.abspath(model_path)
    if _global_rknn is not None:
        if requested != _global_rknn_path:
            raise RuntimeError(
                f"RKNN model already loaded from {_global_rknn_path}; "
                f"cannot switch to {requested}"
            )
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Free the handle so a failed load does not leak runtime resources.
        rknn.release()
        raise

    _global_rknn = rknn
    _global_rknn_path = requested
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn


# =====================================================
# Image preprocessing
# =====================================================
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Fit ``image`` into a ``new_size`` square keeping aspect ratio, padding with ``color``.

    Assumes a 3-channel image: the output canvas is allocated with 3 channels.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # Clamp to 1 px so extreme aspect ratios never request a 0-sized resize.
    nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))

    canvas = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    canvas[top:top + nh, left:left + nw] = resized
    return canvas


def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without preserving aspect ratio."""
    return cv2.resize(image, (size, size))


def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC",
):
    """Turn a BGR image into a contiguous float32 batch-of-one tensor.

    Args:
        img: BGR image as read by OpenCV.
        size: output side length in pixels.
        resize_mode: "letterbox" keeps aspect ratio; any other value stretches.
        to_rgb: convert BGR -> RGB.
        normalize: divide pixel values by 255.
        layout: "NHWC" -> (1, H, W, C); any other value -> (1, C, H, W).

    Returns:
        A C-contiguous float32 numpy array.
    """
    if resize_mode == "letterbox":
        img = letterbox(img, size)
    else:
        img = resize_stretch(img, size)

    if to_rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = img.astype(np.float32)
    if normalize:
        img /= 255.0

    if layout == "NHWC":
        img = np.expand_dims(img, axis=0)
    else:
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    return np.ascontiguousarray(img)


# =====================================================
# Single RKNN inference
# =====================================================
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run the classifier on a preprocessed tensor.

    Returns ``(class_id, probs)`` where ``probs`` is the first model output
    flattened to a float array and ``class_id`` is its argmax.

    NOTE(review): whether ``probs`` are softmax probabilities or raw logits
    depends on how the model was exported; confirm before treating them as
    confidences.

    Raises:
        RuntimeError: if the runtime returns no outputs.
    """
    rknn = init_rknn_model(model_path)
    outs = rknn.inference([input_tensor])
    if not outs:
        raise RuntimeError("RKNN inference returned no outputs")

    probs = outs[0].reshape(-1).astype(float)
    class_id = int(np.argmax(probs))
    return class_id, probs


# =====================================================
# ROI handling
# =====================================================
def load_single_roi(txt_path):
    """Read the first non-blank ``x,y,w,h`` line of ``txt_path`` as an int tuple.

    Raises:
        RuntimeError: if the file is missing or contains no non-blank line.
        ValueError: if the first non-blank line is not exactly four integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")

    # utf-8-sig tolerates a BOM left by Windows editors (int() would choke on it).
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split(",")
            if len(parts) != 4:
                raise ValueError(f"ROI 格式应为 x,y,w,h: {line!r}")
            x, y, w, h = map(int, parts)
            return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")


def crop_and_return_roi(img, roi):
    """Return the ``(x, y, w, h)`` sub-image of ``img`` (a view, not a copy).

    Raises:
        RuntimeError: if the ROI is empty or extends outside the image.
    """
    x, y, w, h = roi
    H, W = img.shape[:2]

    # Non-positive w/h would pass the bounds check but produce an empty crop
    # that later crashes cv2.resize.
    if w <= 0 or h <= 0 or x < 0 or y < 0 or x + w > W or y + h > H:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")

    return img[y:y + h, x:x + w]


# =====================================================
# Single-frame classification
# =====================================================
def classify_single_image(frame, model_path, roi_file, roi=None):
    """Crop the ROI from ``frame`` and classify it.

    Args:
        frame: BGR image.
        model_path: path of the .rknn model (loaded once, see init_rknn_model).
        roi_file: ROI text file; only read when ``roi`` is None.
        roi: optional pre-loaded ``(x, y, w, h)``; pass it in a frame loop to
            avoid re-reading the ROI file for every frame.

    Returns:
        dict with ``class_id``, ``class`` (display name), ``score`` (output at
        class_id rounded to 4 places) and ``raw`` (all outputs as a list).
    """
    if roi is None:
        roi = load_single_roi(roi_file)
    roi_img = crop_and_return_roi(frame, roi)

    input_tensor = preprocess_image_for_rknn(
        roi_img,
        size=640,
        resize_mode="stretch",
        to_rgb=True,
        normalize=False,
        layout="NHWC",
    )

    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)

    return {
        "class_id": class_id,
        # Fall back to a generic name if the model has more outputs than CLASS_NAMES.
        "class": CLASS_NAMES.get(class_id, f"class{class_id}"),
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist(),
    }


# =====================================================
# RTSP inference + save classified frames
# =====================================================
def _redact_url(url):
    """Return ``url`` with any ``user:password@`` part replaced by ``***@`` for logging."""
    parts = urlsplit(url)
    if "@" not in parts.netloc:
        return url
    host = parts.netloc.rsplit("@", 1)[1]
    return urlunsplit(parts._replace(netloc=f"***@{host}"))


def run_rtsp_classification_and_save(
    model_path,
    roi_file,
    rtsp_url,
    save_root="clsimg",
    stable_frames=3,
    save_mode="all",  # all / stable
    flip_code=-1,
    show=False,
    max_read_failures=50,
    reconnect_delay=1.0,
):
    """Classify every RTSP frame and save frames into per-class folders.

    Args:
        model_path: .rknn model path.
        roi_file: ROI text file, read once at start-up (restart to pick up edits).
        rtsp_url: stream URL.
        save_root: frames go to ``<save_root>/class<id>/``.
        stable_frames: consecutive agreeing frames required by StableClassJudge.
        save_mode: "all" saves every frame under its per-frame class;
            "stable" saves only frames where the judge emits a decision.
        flip_code: passed to cv2.flip (-1 = rotate 180 deg); None disables flipping.
        show: display frames in a window; press 'q' in it to quit. Without a
            window cv2.waitKey cannot receive keys, so stop with Ctrl+C.
        max_read_failures: consecutive failed reads before the stream is reopened.
        reconnect_delay: seconds to wait before reopening the stream.

    Raises:
        ValueError: on an unknown ``save_mode``.
        RuntimeError: if the stream cannot be opened initially, or the ROI is invalid.
    """
    if save_mode not in ("all", "stable"):
        raise ValueError(f"save_mode must be 'all' or 'stable', got {save_mode!r}")

    # Load and validate the ROI before touching the network.
    roi = load_single_roi(roi_file)

    for cid in CLASS_NAMES:
        os.makedirs(os.path.join(save_root, f"class{cid}"), exist_ok=True)

    cap = cv2.VideoCapture(rtsp_url)
    if not cap.isOpened():
        raise RuntimeError(f"无法打开 RTSP: {_redact_url(rtsp_url)}")

    judge = StableClassJudge(stable_frames=stable_frames, ignore_class=2)
    print("[INFO] RTSP 推理开始")

    failures = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                failures += 1
                print("[WARN] RTSP 读帧失败")
                if failures >= max_read_failures:
                    # A dropped RTSP session never recovers by itself: reopen it.
                    print("[WARN] RTSP 连续读帧失败, 尝试重连")
                    cap.release()
                    time.sleep(reconnect_delay)
                    cap = cv2.VideoCapture(rtsp_url)
                    # Frames across the gap are not consecutive.
                    judge.reset()
                    failures = 0
                else:
                    time.sleep(0.1)
                continue
            failures = 0

            if flip_code is not None:
                frame = cv2.flip(frame, flip_code)

            result = classify_single_image(frame, model_path, roi_file, roi=roi)
            class_id = result["class_id"]
            score = result["score"]
            print(f"[FRAME] {result['class']} conf={score}")

            stable = judge.update(class_id)

            save_flag = False
            save_class = class_id
            if save_mode == "all":
                save_flag = True
            elif stable is not None:
                save_flag = True
                save_class = stable

            if save_flag:
                ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filename = f"{ts}_conf{score:.2f}.jpg"
                save_dir = os.path.join(save_root, f"class{save_class}")
                # Covers class ids beyond CLASS_NAMES, which were not pre-created.
                os.makedirs(save_dir, exist_ok=True)
                save_path = os.path.join(save_dir, filename)
                if cv2.imwrite(save_path, frame):
                    print(f"[SAVE] class{save_class}/{filename}")
                else:
                    print(f"[WARN] 保存失败: {save_path}")

            if show:
                cv2.imshow("rtsp", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    except KeyboardInterrupt:
        print("[INFO] 收到中断, 退出")
    finally:
        cap.release()
        if show:
            cv2.destroyAllWindows()


# =====================================================
# main
# =====================================================
if __name__ == "__main__":
    model_path = "muju_cls.rknn"
    roi_file = "./roi_coordinates/muju_roi.txt"
    # NOTE(review): camera credentials are hard-coded here; prefer supplying
    # the URL via the MUJU_RTSP_URL environment variable and dropping this default.
    rtsp_url = os.environ.get(
        "MUJU_RTSP_URL",
        "rtsp://admin:XJ123456@192.168.250.61:554/streaming/channels/101",
    )

    run_rtsp_classification_and_save(
        model_path=model_path,
        roi_file=roi_file,
        rtsp_url=rtsp_url,
        save_root="clsimg",
        stable_frames=3,
        save_mode="all",  # set to "stable" to save only debounced decisions
    )