# -*- coding: utf-8 -*- """ ROI RKNN 图像分类模块 基于 RKNNLite 对输入图像的指定 ROI 区域进行分类, 输出类别 ID(0: 异常,1: 正常)。 支持: - RKNN 模型单例加载 - ROI 裁剪与缩放 - BGR → RGB 预处理 - 主程序测试入口 """ import os from typing import Dict,Union import cv2 import numpy as np from rknnlite.api import RKNNLite # ===================================================== # 全局配置(常量) # ===================================================== RKNN_MODEL_PATH: str = "wood_exist_cls.rknn" # ROI 坐标:x1, y1, x2, y2(像素坐标) ROI: tuple[int, int, int, int] = (3, 0, 694, 182) CLASS_NAMES: Dict[int, str] = { 0: "异常", 1: "正常", } # ===================================================== # 全局 RKNN 实例(单例) # ===================================================== _global_rknn: Union[RKNNLite, None] = None def _init_rknn_model(model_path: str) -> RKNNLite: """ 初始化并返回 RKNN 模型(单例模式)。 Args: model_path (str): RKNN 模型路径 Returns: RKNNLite: 已初始化的 RKNNLite 实例 Raises: FileNotFoundError: 模型文件不存在 RuntimeError: RKNN 加载或运行时初始化失败 """ global _global_rknn if _global_rknn is not None: return _global_rknn if not os.path.exists(model_path): raise FileNotFoundError(f"RKNN 模型不存在: {model_path}") rknn = RKNNLite(verbose=False) ret = rknn.load_rknn(model_path) if ret != 0: raise RuntimeError(f"Load RKNN failed: {ret}") ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0) if ret != 0: raise RuntimeError(f"Init runtime failed: {ret}") _global_rknn = rknn print(f"[INFO] RKNN 模型加载成功: {model_path}") return rknn # ===================================================== # 预处理函数 # ===================================================== def _preprocess_input(img: np.ndarray) -> np.ndarray: """ 对 ROI 图像进行模型输入预处理。 Args: img (np.ndarray): BGR 格式图像,shape=(640, 640, 3) Returns: np.ndarray: NHWC 格式 float32 输入张量,shape=(1, 640, 640, 3) Raises: ValueError: 输入图像尺寸不符合要求 """ if img.shape[:2] != (640, 640): raise ValueError("输入图像必须是 640x640") img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) input_tensor = np.expand_dims(img_rgb.astype(np.float32), axis=0) return np.ascontiguousarray(input_tensor) # ===================================================== # ROI + RKNN 分类器 # ===================================================== class ROIClassifierRKNN: """ 基于 RKNN 的 ROI 区域分类器。 功能: - 加载 RKNN 分类模型 - 从原始图像中裁剪 ROI - 执行分类推理并返回类别 ID """ def __init__(self, model_path: str) -> None: """ 初始化分类器。 Args: model_path (str): RKNN 模型路径 """ self.rknn = _init_rknn_model(model_path) def classify(self, img_np: np.ndarray) -> int: """ 对输入图像进行 ROI 分类。 Args: img_np (np.ndarray): 原始 BGR 图像 Returns: int: 分类结果(0: 异常,1: 正常) Raises: ValueError: ROI 坐标非法 """ height, width = img_np.shape[:2] x1, y1, x2, y2 = ROI # -------- ROI 边界保护 -------- x1 = max(0, min(x1, width - 1)) x2 = max(0, min(x2, width)) y1 = max(0, min(y1, height - 1)) y2 = max(0, min(y2, height)) if x2 <= x1 or y2 <= y1: raise ValueError(f"ROI 坐标无效: {(x1, y1, x2, y2)}") # -------- 1. 裁剪 ROI -------- roi_img = img_np[y1:y2, x1:x2] # -------- 2. resize 到 640×640 -------- roi_img = cv2.resize(roi_img, (640, 640)) # -------- 3. 预处理 -------- input_tensor = _preprocess_input(roi_img) # -------- 4. RKNN 推理 -------- outputs = self.rknn.inference([input_tensor]) logits = outputs[0].reshape(-1).astype(float) return int(np.argmax(logits)) # ===================================================== # 对外接口 # ===================================================== _classifier = ROIClassifierRKNN(RKNN_MODEL_PATH) def classify_wood_exist(img_np: np.ndarray) -> int: """ 线条是否存在图像分类接口函数。 Args: img_np (np.ndarray): 原始 BGR 图像 Returns: int: 分类结果(0 / 1) """ return _classifier.classify(img_np) # ===================================================== # 测试入口 # ===================================================== if __name__ == "__main__": img_path = "1.png" if not os.path.exists(img_path): raise FileNotFoundError(f"图片不存在: {img_path}") img = cv2.imread(img_path) if img is None: raise ValueError("图像加载失败") result = classify_wood_exist(img) print( f"NG料结果:{result} " f"({CLASS_NAMES[result]})" )