"""YOLOv8 defect detection on Rockchip NPUs through RKNN-Lite.

The module loads an ``.rknn`` detection model once per process, runs it on an
RGB image and reports whether any defect (see ``labels``) was detected.
"""

import platform
import sys

import cv2
import numpy as np
from rknnlite.api import RKNNLite

# ------------------- Globals -------------------
# Process-wide RKNNLite runtime, created lazily by init_rknn_detector().
_global_rknn_detector = None
# Class id -> name. Must match the class order used when the model was trained.
labels = {0: 'hole', 1: 'crack'}
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
# (width, height) the image is resized to before inference.
_MODEL_INPUT_SIZE = (640, 640)
# SoC substrings looked up in the device-tree node, in priority order.
_KNOWN_SOCS = ('rk3562', 'rk3576', 'rk3588')


# ------------------- Host information -------------------
def get_host():
    """Return an identifier for the platform this process runs on.

    On ``Linux-aarch64`` the device-tree ``compatible`` node is read and
    mapped to ``'RK3562'``, ``'RK3576'``, ``'RK3588'`` or, when none of
    those substrings is present, ``'RK3566_RK3568'``.  On any other platform
    the ``'<system>-<machine>'`` string is returned unchanged.

    If the device-tree node cannot be read, a message is printed and the
    process exits with status -1.
    """
    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    try:
        with open(DEVICE_COMPATIBLE_NODE) as f:
            device_compatible_str = f.read()
    except OSError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist (e.g. ``python -S`` or frozen applications).
        sys.exit(-1)

    for soc in _KNOWN_SOCS:
        if soc in device_compatible_str:
            return soc.upper()
    return 'RK3566_RK3568'


# ------------------- RKNN detection model initialisation (singleton) -------------------
def init_rknn_detector(model_path):
    """Return the shared RKNNLite runtime, loading it on first use.

    Args:
        model_path: Path of the ``.rknn`` model file to load.

    Returns:
        The initialised ``RKNNLite`` instance (bound to NPU core 0).

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status.

    NOTE(review): once a runtime has been created, ``model_path`` is ignored
    and the cached instance is returned even if a different path is passed.
    Confirm that callers never need more than one model per process.
    """
    global _global_rknn_detector
    if _global_rknn_detector is not None:
        return _global_rknn_detector

    rknn = RKNNLite(verbose=False)

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        raise RuntimeError(f'Load RKNN detection model failed: {ret}')

    ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        raise RuntimeError(f'Init RKNN runtime failed: {ret}')

    _global_rknn_detector = rknn
    print(f'[INFO] RKNN detection model loaded: {model_path}')
    return _global_rknn_detector


# ------------------- Pre-processing: RGB ndarray -> model input -------------------
def preprocess_rgb_image(rgb_image, target_size=(640, 640)):
    """Resize an RGB image and convert it to a batched channel-first tensor.

    Args:
        rgb_image: ``(H, W, C)`` image array in RGB channel order.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``np.ndarray`` of shape ``(1, C, height, width)``.  The image is only
        resized (no letterboxing), so the aspect ratio is not preserved.
    """
    resized = cv2.resize(rgb_image, target_size)
    batched = np.expand_dims(resized, 0)          # (1, H, W, C)
    return np.transpose(batched, (0, 3, 1, 2))    # (1, C, H, W)


# ------------------- Post-processing -------------------
def postprocess_yolov8(outputs, conf_threshold=0.5, input_size=(640, 640)):
    """Decode the raw model output into a list of detections.

    The first output is squeezed (when 3-D) and transposed so that every row
    is one candidate: columns 0-3 are read as ``cx, cy, w, h`` and the
    remaining columns as per-class scores.
    Assumes the model emits ``(1, 4 + num_classes, num_candidates)`` —
    TODO confirm against the exported model.

    No non-maximum suppression is applied here; overlapping candidates are
    all returned.

    Args:
        outputs: Sequence of arrays returned by the inference call; only
            ``outputs[0]`` is used.
        conf_threshold: Minimum best-class score for a candidate to be kept.
        input_size: Unused; kept for interface compatibility.

    Returns:
        List of dicts with keys ``'bbox'`` (``[x1, y1, x2, y2]`` in model
        input coordinates), ``'confidence'`` (float) and ``'class_id'`` (int).
    """
    pred = outputs[0]
    if pred.ndim == 3:
        pred = np.squeeze(pred, axis=0)
    pred = pred.T

    boxes = pred[:, :4]   # cx, cy, w, h
    scores = pred[:, 4:]
    class_ids = np.argmax(scores, axis=1)
    confidences = np.max(scores, axis=1)

    keep = confidences >= conf_threshold
    boxes = boxes[keep]
    confidences = confidences[keep]
    class_ids = class_ids[keep]

    # Centre/size -> corner representation.
    centers = boxes[:, :2]
    half_sizes = boxes[:, 2:4] / 2
    boxes_xyxy = np.concatenate((centers - half_sizes, centers + half_sizes), axis=1)

    return [
        {
            'bbox': box.tolist(),
            'confidence': float(conf),
            'class_id': int(cls_id),
        }
        for box, conf, cls_id in zip(boxes_xyxy, confidences, class_ids)
    ]


# ------------------- Map coordinates back to the original image -------------------
def scale_detections_to_original(detections, original_shape, input_size=(640, 640)):
    """Rescale boxes from model-input coordinates to original-image coordinates.

    Args:
        detections: List of detection dicts as produced by
            ``postprocess_yolov8``.
        original_shape: Shape of the original image; the first two entries
            are read as ``(height, width)``.
        input_size: ``(width, height)`` of the model input.

    Returns:
        A new list of detection dicts with scaled ``'bbox'`` values.  An empty
        input is returned as-is.
    """
    if not detections:
        return detections

    orig_h, orig_w = original_shape[:2]
    input_w, input_h = input_size
    scale_x = orig_w / input_w
    scale_y = orig_h / input_h

    scaled = []
    for det in detections:
        x1, y1, x2, y2 = det['bbox']
        scaled.append({
            'bbox': [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y],
            'confidence': det['confidence'],
            'class_id': det['class_id'],
        })
    return scaled


def _show_debug_result(rgb_image, top_detections):
    """Print the detections and display them drawn on the image (blocking)."""
    # OpenCV drawing/display functions expect BGR channel order.
    bgr_vis = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    for det in top_detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        label = f"{labels.get(det['class_id'], 'unknown')} {det['confidence']:.2f}"
        cv2.rectangle(bgr_vis, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Clamp the text baseline so labels of boxes touching the top edge
        # are still drawn inside the image.
        cv2.putText(bgr_vis, label, (x1, max(y1 - 10, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    print("\n===== RKNN Detection Result (DEBUG) =====")
    print(f"Defect classes detected: {len(top_detections)}")
    for det in top_detections:
        cls_name = labels.get(det['class_id'], 'unknown')
        print(f"  - {cls_name}: conf={det['confidence']:.3f}, bbox={det['bbox']}")

    cv2.imshow('Detection', bgr_vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


# ------------------- Main inference entry point: takes an RGB ndarray -------------------
def detect_quexian_from_rgb_array(
    rgb_image: np.ndarray,
    model_path: str = "xiantiao_detect.rknn",
    conf_threshold: float = 0.5,
    debug: bool = False
) -> bool:
    """Run defect detection on one RGB image.

    Args:
        rgb_image: ``(H, W, 3)`` image array in RGB channel order.
        model_path: Path of the RKNN model (only used the first time a model
            is loaded in this process).
        conf_threshold: Minimum confidence for a detection to count.
        debug: When true, print the detections and show them in a window
            (blocks until a key is pressed).

    Returns:
        ``True`` if no defect was detected (good part), ``False`` otherwise.

    Raises:
        ValueError: If the image is ``None``, empty, or not ``(H, W, 3)``.
        RuntimeError: If the model cannot be loaded or inference yields no
            output.
    """
    if rgb_image is None or rgb_image.size == 0:
        raise ValueError("输入图像为空")
    if rgb_image.ndim != 3 or rgb_image.shape[2] != 3:
        raise ValueError(
            f"Expected an (H, W, 3) RGB image, got shape {rgb_image.shape}"
        )

    # Load the model (singleton).
    rknn = init_rknn_detector(model_path)

    # Pre-process.
    img_input = preprocess_rgb_image(rgb_image, target_size=_MODEL_INPUT_SIZE)

    # Inference.
    # NOTE(review): RKNNLite.inference is documented to treat 4-D inputs as
    # NHWC unless ``data_format`` is given, while img_input is NCHW here —
    # confirm this matches how the model was exported.
    outputs = rknn.inference(inputs=[img_input])
    if outputs is None or len(outputs) == 0:
        raise RuntimeError('RKNN inference returned no output')

    # Post-process (boxes are in model-input coordinates).
    detections = postprocess_yolov8(outputs, conf_threshold=conf_threshold)

    # Map boxes back to original image coordinates.
    detections = scale_detections_to_original(
        detections, rgb_image.shape, input_size=_MODEL_INPUT_SIZE
    )

    # Keep only the highest-confidence box of every class.
    best_per_class = {}
    for det in detections:
        cls_id = det['class_id']
        best = best_per_class.get(cls_id)
        if best is None or det['confidence'] > best['confidence']:
            best_per_class[cls_id] = det
    top_detections = list(best_per_class.values())

    is_good = not top_detections

    if debug:
        _show_debug_result(rgb_image, top_detections)

    return is_good


# ------------------- Script entry point: read an image and run detection -------------------
def main():
    """Read the sample image from disk, run detection and print the verdict."""
    image_path = "./test_image/detect/defect1.jpg"
    model_path = "xiantiao_detect.rknn"

    # 1. Read the image (OpenCV loads it as BGR).
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")

    # 2. Convert to RGB, as expected by the inference function.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # 3. Run inference on the ndarray.
    is_good_product = detect_quexian_from_rgb_array(
        rgb_image=rgb_image,
        model_path=model_path,
        conf_threshold=0.5,
        debug=True
    )

    # 4. Report the result.
    if is_good_product:
        print("✅ 产品合格(无缺陷)")
    else:
        print("❌ 产品存在缺陷")


if __name__ == '__main__':
    main()