# -*- coding: utf-8 -*-
"""
Wood-strip detection module (RKNNLite + ROI + single-class NMS).

Features:
- Detect wood strips inside a fixed ROI of the input image.
- Report the number of detected wood strips.
"""
import os
from typing import Optional

import cv2
import numpy as np
from rknnlite.api import RKNNLite

# =====================================================
# Constants
# =====================================================
RKNN_MODEL_PATH: str = "wood_detect.rknn"
ROI: tuple[int, int, int, int] = (1, 1, 1000, 1000)  # ROI corners (x1, y1, x2, y2) in pixels
IMG_SIZE: tuple[int, int] = (640, 640)  # model input size (width, height)
OBJ_THRESH: float = 0.25  # confidence threshold
NMS_THRESH: float = 0.45  # IoU threshold used by NMS
CLASS_NAME: list[str] = ["bag"]  # single class name; to be renamed to "wood" later

# How to treat the third (1-channel) output tensor of every detection branch.
# NOTE(review): three outputs per stride with a DFL box head matches the
# rknn_model_zoo YOLOv8 export, where that tensor is `score_sum` (clipped sum
# of the class scores, intended only as a cheap pre-filter), NOT objectness.
# Multiplying by it squares a single-class score, i.e. the effective class
# threshold becomes sqrt(OBJ_THRESH). Set this to True only if your model
# really emits an objectness score in that slot (this was the old behavior).
THIRD_BRANCH_IS_OBJECTNESS: bool = False


# =====================================================
# Private helpers
# =====================================================
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along ``axis``."""
    # Subtract the max first so np.exp cannot overflow.
    x = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def _letterbox_resize(image: np.ndarray, size: tuple[int, int],
                      bg_color: int = 114) -> tuple[np.ndarray, float, int, int]:
    """Resize ``image`` keeping its aspect ratio and pad it (centered) to ``size``.

    Args:
        image: Source image; the padded canvas is allocated with 3 channels.
        size: Target (width, height).
        bg_color: Fill value for the padding.

    Returns:
        ``(canvas, scale, dx, dy)``: the padded uint8 image, the resize factor
        applied to ``image``, and the left/top padding in pixels. A canvas
        point (x, y) maps back to ((x - dx) / scale, (y - dy) / scale).

    Raises:
        ValueError: if ``image`` has zero width or height.
    """
    target_w, target_h = size
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot letterbox an empty image of shape {image.shape}")
    scale = min(target_w / w, target_h / h)
    # Clamp to 1 px so extreme aspect ratios never request a zero-sized resize.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h))
    canvas = np.full((target_h, target_w, 3), bg_color, dtype=np.uint8)
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized
    return canvas, scale, dx, dy


def _nms(boxes: np.ndarray, scores: np.ndarray, thresh: float) -> list[int]:
    """Greedy non-maximum suppression.

    Args:
        boxes: (N, 4) array of (x1, y1, x2, y2).
        scores: (N,) confidence of each box.
        thresh: A box is dropped when its IoU with an already-kept box exceeds this.

    Returns:
        Indices into ``boxes`` of the kept boxes, highest score first.
    """
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]
    keep: list[int] = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        union = areas[i] + areas[rest] - inter
        # Guard 0/0 for degenerate (zero-area) boxes, which would give NaN IoU.
        iou = inter / np.maximum(union, 1e-9)
        order = rest[iou <= thresh]
    return keep


def _post_process(outputs: list[np.ndarray], scale: float, dx: int, dy: int) -> Optional[np.ndarray]:
    """Decode raw RKNN outputs into boxes and run single-class NMS.

    The indexing below expects three tensors per stride (8, 16, 32), in the
    order: box regression with 4 * reg_max channels (DFL distributions for
    left/top/right/bottom), class scores, and a 1-channel score tensor. Each
    tensor is indexed with ``[0]``, i.e. a batch size of 1 is assumed.

    Args:
        outputs: Raw model outputs as returned by ``RKNNLite.inference``.
        scale: Resize factor returned by ``_letterbox_resize``.
        dx: Left padding returned by ``_letterbox_resize``.
        dy: Top padding returned by ``_letterbox_resize``.

    Returns:
        (K, 4) array of (x1, y1, x2, y2) in the coordinate frame of the image
        that was letterboxed, or ``None`` if no candidate reaches ``OBJ_THRESH``.

    Raises:
        ValueError: if fewer than 9 output tensors are supplied.
    """
    strides = (8, 16, 32)
    expected = 3 * len(strides)
    if len(outputs) < expected:
        raise ValueError(f"expected {expected} output tensors, got {len(outputs)}")

    boxes_list: list[np.ndarray] = []
    scores_list: list[np.ndarray] = []
    for i, stride in enumerate(strides):
        reg = outputs[i * 3 + 0][0]  # (4 * reg_max, H, W)
        cls = outputs[i * 3 + 1][0]  # (num_classes, H, W)
        aux = outputs[i * 3 + 2][0]  # 1-channel tensor; see THIRD_BRANCH_IS_OBJECTNESS
        num_classes, _, grid_w = cls.shape
        reg_max = reg.shape[0] // 4

        # Single-class counting: only the best class score per cell matters.
        scores = cls.reshape(num_classes, -1).max(axis=0)
        if THIRD_BRANCH_IS_OBJECTNESS:
            scores = scores * aux.ravel()

        valid_idx = np.flatnonzero(scores >= OBJ_THRESH)
        if valid_idx.size == 0:
            continue
        scores_v = scores[valid_idx]

        # Cells were flattened row-major from (H, W): index = y * W + x.
        gy, gx = np.divmod(valid_idx, grid_w)
        gx = gx.astype(np.float32)
        gy = gy.astype(np.float32)

        # DFL decode: expected value of each side's discrete distribution, in grid units.
        dist = reg.reshape(4, reg_max, -1)[:, :, valid_idx]
        bins = np.arange(reg_max, dtype=np.float32).reshape(1, -1, 1)
        ltrb = np.sum(_softmax(dist, axis=1) * bins, axis=1)  # (4, N)

        cx = (gx + 0.5) * stride
        cy = (gy + 0.5) * stride
        boxes = np.stack(
            [
                cx - ltrb[0] * stride,
                cy - ltrb[1] * stride,
                cx + ltrb[2] * stride,
                cy + ltrb[3] * stride,
            ],
            axis=1,
        )
        # Undo the letterbox: remove the padding, then undo the resize.
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] - dx) / scale
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] - dy) / scale

        boxes_list.append(boxes)
        scores_list.append(scores_v)

    if not boxes_list:
        return None

    boxes_all = np.concatenate(boxes_list, axis=0)
    scores_all = np.concatenate(scores_list, axis=0)
    keep_idx = _nms(boxes_all, scores_all, NMS_THRESH)
    return boxes_all[keep_idx]


# =====================================================
# Global RKNN model initialization (runs at import time)
# =====================================================
_global_rknn = RKNNLite()
# RKNNLite reports failure through a non-zero return code rather than an
# exception; fail fast here instead of crashing later inside _post_process.
if _global_rknn.load_rknn(RKNN_MODEL_PATH) != 0:
    _global_rknn.release()
    raise RuntimeError(f"Failed to load RKNN model: {RKNN_MODEL_PATH}")
if _global_rknn.init_runtime() != 0:
    _global_rknn.release()
    raise RuntimeError("Failed to initialize RKNN runtime")

# =====================================================
# Wood detector class
# =====================================================
class WoodDetectorRKNN:
    """Wood-strip detector backed by the shared module-level RKNNLite runtime."""

    def __init__(self):
        """Bind this detector to the already-initialized global RKNN runtime."""
        self.rknn = _global_rknn

    def detect(self, img_np: np.ndarray) -> int:
        """
        Count wood strips inside the configured ROI.

        Args:
            img_np (np.ndarray): Original BGR image; grayscale and BGRA images
                are converted to BGR first.

        Returns:
            int: Number of detected wood strips.

        Raises:
            ValueError: If ``img_np`` is None or the ROI does not overlap the image.
            RuntimeError: If RKNN inference fails.
        """
        if img_np is None:
            raise ValueError("input image is None")

        # Crop the ROI, clamped to the image bounds.
        h, w = img_np.shape[:2]
        x1, y1, x2, y2 = ROI
        x1 = max(0, min(x1, w - 1))
        x2 = max(0, min(x2, w))
        y1 = max(0, min(y1, h - 1))
        y2 = max(0, min(y2, h))
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"ROI {ROI} does not overlap image of size {w}x{h}")
        roi_img = img_np[y1:y2, x1:x2]

        # The letterbox canvas is 3-channel; normalize other channel layouts.
        if roi_img.ndim == 2 or roi_img.shape[2] == 1:
            roi_img = cv2.cvtColor(roi_img, cv2.COLOR_GRAY2BGR)
        elif roi_img.shape[2] == 4:
            roi_img = cv2.cvtColor(roi_img, cv2.COLOR_BGRA2BGR)

        resized_img, scale, dx, dy = _letterbox_resize(roi_img, IMG_SIZE)

        # NOTE(review): the image is fed as BGR float32. Ultralytics-trained
        # models and the rknn_model_zoo examples usually expect RGB uint8; this
        # is only correct if the .rknn conversion config compensates — confirm.
        input_tensor = np.expand_dims(resized_img.astype(np.float32), axis=0)
        outputs = self.rknn.inference([input_tensor])
        if outputs is None:
            raise RuntimeError("RKNN inference failed")

        boxes = _post_process(outputs, scale, dx, dy)
        return 0 if boxes is None else len(boxes)


# =====================================================
# Public interface
# =====================================================
_detector = WoodDetectorRKNN()


def detect_wood(img_np: np.ndarray) -> int:
    """
    Public wood-strip detection entry point (returns the count).

    Args:
        img_np (np.ndarray): Original BGR image.

    Returns:
        int: Number of detected wood strips.
    """
    return _detector.detect(img_np)


# =====================================================
# Script entry point
# =====================================================
if __name__ == "__main__":
    img_path = "1.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"图片不存在: {img_path}")

    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("图像加载失败")

    total_count = detect_wood(img)
    print(f"检测到木条数量: {total_count}")