import os

import cv2
import numpy as np
from rknnlite.api import RKNNLite

# ====================== Configuration ======================
MODEL_PATH = "bag3568.rknn"
IMG_SIZE = (640, 640)  # (width, height) of the letterboxed network input
OBJ_THRESH = 0.25  # minimum class_score * obj_score for a candidate box
NMS_THRESH = 0.45  # boxes with IoU above this are suppressed
CLASS_NAME = ["bag", "bag35"]


# ====================== Utility functions ======================
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-axis maximum is subtracted first so ``np.exp`` cannot overflow.
    """
    x = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping aspect ratio, then pad.

    Args:
        image: H x W x 3 image array.
        size: Target ``(width, height)``.
        bg_color: Fill value for the padding area.

    Returns:
        ``(canvas, scale, dx, dy)`` where ``canvas`` is the padded uint8 image,
        ``scale`` the resize factor, and ``dx`` / ``dy`` the left / top padding
        offsets in pixels.
    """
    target_w, target_h = size
    h, w = image.shape[:2]
    scale = min(target_w / w, target_h / h)
    # Clamp to 1 px: for extreme aspect ratios int() can truncate to 0, which
    # cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h))

    canvas = np.full((target_h, target_w, 3), bg_color, dtype=np.uint8)
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized
    return canvas, scale, dx, dy


# ====================== DFL decoding ======================
def dfl_decode(reg):
    """Decode one cell's regression vector into four distances.

    ``reg`` is reshaped to ``(4, bins)``; each row is soft-maxed and reduced to
    its expected bin index, yielding ``[l, t, r, b]`` in units of the stride.
    """
    reg = reg.reshape(4, -1)
    prob = softmax(reg, axis=1)
    acc = np.arange(reg.shape[1])
    return np.sum(prob * acc, axis=1)


# ====================== NMS ======================
def nms(boxes, scores, thresh):
    """Greedy class-agnostic non-maximum suppression.

    Args:
        boxes: Sequence of ``[x1, y1, x2, y2]`` boxes.
        scores: Sequence of scores, one per box.
        thresh: Boxes whose IoU with a kept box exceeds this are dropped.

    Returns:
        List of kept indices into ``boxes``, highest score first.
    """
    boxes = np.array(boxes)
    scores = np.array(scores)

    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        rest = order[1:]
        keep.append(int(i))

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        union = areas[i] + areas[rest] - inter
        # Guard the denominator: two degenerate (zero-area) boxes give a zero
        # union, which would otherwise produce NaN and a RuntimeWarning.
        iou = inter / np.maximum(union, 1e-9)
        order = rest[iou <= thresh]

    return keep


# ====================== Post-processing ======================
def post_process(outputs, scale, dx, dy):
    """Turn raw network outputs into boxes in original-image coordinates.

    ``outputs`` is read as three consecutive ``(reg, cls, obj)`` triples, one
    per stride in ``[8, 16, 32]``; the leading (batch) axis of each tensor is
    dropped by taking element 0.

    Args:
        outputs: Sequence of at least nine arrays as described above.
        scale: Letterbox resize factor.
        dx: Letterbox left padding in pixels.
        dy: Letterbox top padding in pixels.

    Returns:
        ``(boxes, classes, scores)`` arrays after NMS, or
        ``(None, None, None)`` when no cell reaches ``OBJ_THRESH``.
    """
    boxes_all, scores_all, classes_all = [], [], []
    strides = [8, 16, 32]

    for i, stride in enumerate(strides):
        reg = outputs[i * 3 + 0][0]
        cls = outputs[i * 3 + 1][0]
        obj = outputs[i * 3 + 2][0]

        # Score every cell at once instead of in a Python double loop; only
        # the cells that survive the threshold are decoded below.
        cls_ids = np.argmax(cls, axis=0)
        cell_scores = np.max(cls, axis=0) * obj[0]
        # "not (score < thresh)" keeps the original comparison semantics.
        # np.nonzero walks row-major, i.e. the same h-then-w order as before.
        hs, ws = np.nonzero(~(cell_scores < OBJ_THRESH))

        for h, w in zip(hs.tolist(), ws.tolist()):
            l, t, r, b = dfl_decode(reg[:, h, w])

            # Cell centre in network-input pixels.
            cx = (w + 0.5) * stride
            cy = (h + 0.5) * stride

            boxes_all.append([
                cx - l * stride,
                cy - t * stride,
                cx + r * stride,
                cy + b * stride,
            ])
            scores_all.append(cell_scores[h, w])
            classes_all.append(int(cls_ids[h, w]))

    if not boxes_all:
        return None, None, None

    keep = nms(boxes_all, scores_all, NMS_THRESH)

    boxes = np.array(boxes_all)[keep]
    scores = np.array(scores_all)[keep]
    classes = np.array(classes_all)[keep]

    # Undo the letterbox: remove padding, then rescale to the source image.
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - dx) / scale
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - dy) / scale

    return boxes, classes, scores


# ====================== RKNN initialisation (once, at import) ======================
_rknn = RKNNLite()
# The RKNNLite calls report failure through a non-zero return code; fail fast
# here instead of surfacing an obscure error at the first inference.
if _rknn.load_rknn(MODEL_PATH) != 0:
    raise RuntimeError(f"Failed to load RKNN model: {MODEL_PATH}")
if _rknn.init_runtime() != 0:
    raise RuntimeError("Failed to initialise the RKNN runtime")


# ====================== Unified entry point ======================
def detect_bag(img, return_vis=False):
    """Run detection on one image and return the highest-scoring box.

    Args:
        img (np.ndarray): Original image, BGR.
        return_vis (bool): Also return a copy of ``img`` with the box drawn.

    Returns:
        ``(cls, conf, min_x)`` or, when ``return_vis`` is true,
        ``(cls, conf, min_x, vis_img)``:

        * cls (str | None): class name of the best detection.
        * conf (float | None): its score.
        * min_x (int | None): left edge of its box, clipped to the image.
        * vis_img (np.ndarray): annotated copy of ``img`` (plain copy when
          nothing was detected).

    Raises:
        ValueError: ``img`` is ``None`` or empty.
        RuntimeError: the RKNN runtime returned no outputs.
    """
    if img is None or img.size == 0:
        raise ValueError("detect_bag() requires a non-empty image")

    img_r, scale, dx, dy = letterbox_resize(img, IMG_SIZE)
    # NOTE(review): the image is fed as-is (BGR, uint8, NHWC); confirm the
    # model was converted expecting this channel order.
    outputs = _rknn.inference([np.expand_dims(img_r, 0)])
    if outputs is None:
        raise RuntimeError("RKNN inference returned no outputs")

    boxes, cls_ids, scores = post_process(outputs, scale, dx, dy)

    if boxes is None or len(scores) == 0:
        if return_vis:
            return None, None, None, img.copy()
        return None, None, None

    best_idx = int(np.argmax(scores))
    conf = float(scores[best_idx])
    cls_id = int(cls_ids[best_idx])
    cls = CLASS_NAME[cls_id]

    # Clip to the image and convert to plain Python ints: decoded boxes can
    # extend into the letterbox padding, and OpenCV drawing calls are picky
    # about NumPy scalar types in point tuples.
    img_h, img_w = img.shape[:2]
    bx1, by1, bx2, by2 = boxes[best_idx].astype(int)
    x1 = min(max(int(bx1), 0), img_w - 1)
    y1 = min(max(int(by1), 0), img_h - 1)
    x2 = min(max(int(bx2), 0), img_w - 1)
    y2 = min(max(int(by2), 0), img_h - 1)
    min_x = x1

    if return_vis:
        vis = img.copy()
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            vis,
            f"{cls}:{conf:.3f}",
            (x1, max(y1 - 5, 0)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 255, 0),
            2,
        )
        return cls, conf, min_x, vis

    return cls, conf, min_x


# ====================== Manual test ======================
def main():
    """Run detection on a sample image and save the visualisation."""
    IMG_PATH = "./test_image/4.jpg"
    OUTPUT_DIR = "./result"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    img = cv2.imread(IMG_PATH)
    if img is None:
        raise FileNotFoundError(IMG_PATH)

    cls, conf, min_x, vis = detect_bag(img, return_vis=True)

    if cls is None:
        print("未检测到目标")
    else:
        print(f"类别: {cls}")
        print(f"置信度: {conf:.4f}")
        print(f"最左 x: {min_x}")

    if vis is not None:
        save_path = os.path.join(OUTPUT_DIR, "vis_result.jpg")
        # cv2.imwrite signals failure by returning False, not by raising.
        if not cv2.imwrite(save_path, vis):
            raise OSError(f"Failed to write {save_path}")
        print("可视化结果已保存:", save_path)


if __name__ == "__main__":
    main()