Files
xiantiao_CV/class_xiantiao_rknn/main_detect.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

216 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import platform
from rknnlite.api import RKNNLite
# ------------------- Global state -------------------
# Cached RKNNLite instance, created lazily by init_rknn_detector().
_global_rknn_detector = None
# Class id -> class name. Must follow the class order used at training time.
labels = {
    0: 'hole',
    1: 'crack',
}
# Device-tree node that get_host() reads to tell Rockchip SoCs apart.
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
# ------------------- Host information -------------------
def get_host():
    """Return a string identifying the platform this script runs on.

    On ``Linux-aarch64`` the device-tree ``compatible`` node is read to tell
    Rockchip SoCs apart. The result is ``'RK3562'``, ``'RK3576'`` or
    ``'RK3588'`` when the matching tag is present, and ``'RK3566_RK3568'``
    for any other aarch64 board. On every other platform the
    ``"<system>-<machine>"`` string from :mod:`platform` is returned as-is,
    e.g. ``'Linux-x86_64'``.

    If the device-tree node cannot be read, an error is printed to stderr and
    the process exits with status -1 (raises ``SystemExit``).
    """
    import sys

    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    try:
        # The node holds NUL-separated ASCII strings; never let an odd byte
        # turn into an uncaught UnicodeDecodeError.
        with open(DEVICE_COMPATIBLE_NODE, encoding='utf-8', errors='replace') as f:
            device_compatible_str = f.read()
    except OSError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE), file=sys.stderr)
        # Use sys.exit: the bare exit() builtin is injected by the `site`
        # module for interactive use and is missing under `python -S`.
        sys.exit(-1)

    # Checked in this order; the first tag found wins.
    for soc_tag, host in (('rk3562', 'RK3562'), ('rk3576', 'RK3576'), ('rk3588', 'RK3588')):
        if soc_tag in device_compatible_str:
            return host
    return 'RK3566_RK3568'
# ------------------- RKNN detection model initialisation (singleton) -------------------
def init_rknn_detector(model_path):
    """Load the RKNN detection model once and return the shared RKNNLite instance.

    On the first call this:

    * creates an ``RKNNLite`` object;
    * loads ``model_path``;
    * initialises the runtime with ``core_mask=RKNNLite.NPU_CORE_0``;
    * caches the result in the module-level ``_global_rknn_detector``.

    Later calls return the cached instance and ignore ``model_path``.

    Raises:
        RuntimeError: if ``load_rknn`` or ``init_runtime`` returns a non-zero
            status. The half-initialised object is ``release()``d first.
    """
    global _global_rknn_detector
    if _global_rknn_detector is not None:
        # NOTE: model_path is not compared with the already-loaded model.
        return _global_rknn_detector

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f'Load RKNN detection model failed: {ret}')
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f'Init RKNN runtime failed: {ret}')
    except Exception:
        # Without this, every failed attempt leaks the resources held by the
        # discarded RKNNLite object.
        rknn.release()
        raise

    _global_rknn_detector = rknn
    print(f'[INFO] RKNN detection model loaded: {model_path}')
    return _global_rknn_detector
# ------------------- Pre-processing: RGB np.ndarray -> model input -------------------
def preprocess_rgb_image(rgb_image, target_size=(640, 640)):
    """Resize an RGB frame and lay it out as a one-image NCHW batch.

    Args:
        rgb_image: (H, W, C) uint8 array in RGB channel order.
        target_size: (width, height), passed straight to ``cv2.resize``.
            The aspect ratio is not preserved and no letterboxing is applied.

    Returns:
        An array of shape (1, C, target_height, target_width). No dtype
        conversion or normalisation is applied.
    """
    # NOTE(review): RKNN-Toolkit-Lite2 examples usually feed NHWC input;
    # confirm the deployed model really expects this NCHW layout.
    channels_first = cv2.resize(rgb_image, target_size).transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]
# ------------------- Post-processing -------------------
def postprocess_yolov8(outputs, conf_threshold=0.5, input_size=(640, 640)):
    """Decode a raw YOLOv8 detection head into a list of detections.

    Args:
        outputs: inference outputs; only ``outputs[0]`` is used.
            It must have shape (4 + num_classes, num_anchors), optionally
            with a leading batch axis of size 1. Rows 0-3 are box centre x,
            centre y, width and height; the remaining rows are per-class
            scores.
        conf_threshold: anchors whose best class score is below this value
            are dropped.
        input_size: unused; kept for API compatibility.

    Returns:
        A list of dicts, one per kept anchor, with these keys:

        * ``'bbox'``: ``[x1, y1, x2, y2]`` as Python floats, in model-input
          pixel coordinates;
        * ``'confidence'``: float;
        * ``'class_id'``: int.

        No NMS is applied, so overlapping boxes of the same class may all
        be returned.
    """
    pred = outputs[0]
    if pred.ndim == 3:
        pred = np.squeeze(pred, axis=0)
    pred = pred.T  # -> (num_anchors, 4 + num_classes)

    class_scores = pred[:, 4:]
    class_ids = np.argmax(class_scores, axis=1)
    confidences = np.max(class_scores, axis=1)

    keep = confidences >= conf_threshold
    boxes_cxcywh = pred[keep, :4]
    confidences = confidences[keep]
    class_ids = class_ids[keep]

    # (cx, cy, w, h) -> (x1, y1, x2, y2), vectorised over all kept anchors.
    centers = boxes_cxcywh[:, :2]
    half_sizes = boxes_cxcywh[:, 2:4] / 2
    boxes_xyxy = np.concatenate((centers - half_sizes, centers + half_sizes), axis=1)

    return [
        {'bbox': box.tolist(), 'confidence': float(conf), 'class_id': int(cls_id)}
        for box, conf, cls_id in zip(boxes_xyxy, confidences, class_ids)
    ]
# ------------------- Map box coordinates back to the original image -------------------
def scale_detections_to_original(detections, original_shape, input_size=(640, 640)):
    """Rescale detection boxes from model-input pixels to original-image pixels.

    x coordinates are scaled by ``orig_w / input_w`` and y coordinates by
    ``orig_h / input_h``. This matches a plain (non-letterboxed) resize.
    A falsy ``detections`` is returned unchanged (the same object). Otherwise
    a new list of new dicts is returned, and the input dicts are not
    modified.

    Args:
        detections: dicts with keys ``'bbox'`` (x1, y1, x2, y2),
            ``'confidence'`` and ``'class_id'``.
        original_shape: image shape; only the first two entries (height,
            width) are used.
        input_size: (width, height) of the model input.
    """
    if not detections:
        return detections

    dst_h, dst_w = original_shape[:2]
    src_w, src_h = input_size
    scale_x = dst_w / src_w
    scale_y = dst_h / src_h

    def _rescaled(det):
        left, top, right, bottom = det['bbox']
        return {
            'bbox': [left * scale_x, top * scale_y, right * scale_x, bottom * scale_y],
            'confidence': det['confidence'],
            'class_id': det['class_id'],
        }

    return [_rescaled(det) for det in detections]
def _best_detection_per_class(detections):
    """Keep only the highest-confidence detection of each class.

    On ties the earlier detection is kept. The result is ordered by each
    class's first appearance in ``detections``.
    """
    best = {}
    for det in detections:
        current = best.get(det['class_id'])
        if current is None or det['confidence'] > current['confidence']:
            best[det['class_id']] = det
    return list(best.values())


def _show_debug_result(rgb_image, detections):
    """Print ``detections`` and show them drawn on the image (blocks until a key press)."""
    # OpenCV drawing/display expects BGR channel order.
    bgr_vis = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        label = f"{labels.get(det['class_id'], 'unknown')} {det['confidence']:.2f}"
        cv2.rectangle(bgr_vis, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Draw the label above the box. When the box touches the top edge,
        # y1 - 10 would be off-image, so keep the baseline at least 15 px down.
        cv2.putText(bgr_vis, label, (x1, max(y1 - 10, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
    print("\n===== RKNN Detection Result (DEBUG) =====")
    print(f"Defect classes detected: {len(detections)}")
    for det in detections:
        cls_name = labels.get(det['class_id'], 'unknown')
        print(f" - {cls_name}: conf={det['confidence']:.3f}, bbox={det['bbox']}")
    cv2.imshow('Detection', bgr_vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


# ------------------- Main inference function: takes an RGB np.ndarray -------------------
def detect_quexian_from_rgb_array(
    rgb_image: np.ndarray,
    model_path: str = "xiantiao_detect.rknn",
    conf_threshold: float = 0.5,
    debug: bool = False
) -> bool:
    """Detect defects in one RGB image and report whether the product is good.

    Args:
        rgb_image: (H, W, 3) uint8 image in RGB channel order.
        model_path: RKNN model file. It only takes effect on the first call,
            because ``init_rknn_detector`` caches the loaded runtime.
        conf_threshold: minimum class score for a detection to count.
        debug: if True, print the best detection of each class and show
            them in an OpenCV window (blocks until a key is pressed).

    Returns:
        True if no detection reaches ``conf_threshold`` (good product,
        no defect). False if at least one defect is found.

    Raises:
        ValueError: if ``rgb_image`` is None, empty, or not a 3-channel image.
        RuntimeError: if the model cannot be loaded or initialised, or if
            inference returns no outputs.
    """
    if rgb_image is None or rgb_image.size == 0:
        raise ValueError("输入图像为空")
    if rgb_image.ndim != 3 or rgb_image.shape[2] != 3:
        raise ValueError(f"Expected an (H, W, 3) RGB image, got shape {rgb_image.shape}")

    input_size = (640, 640)  # model input (width, height)

    rknn = init_rknn_detector(model_path)
    img_input = preprocess_rgb_image(rgb_image, target_size=input_size)
    outputs = rknn.inference(inputs=[img_input])
    # Guard against a failed inference (None / empty result) instead of
    # failing later with an opaque TypeError/IndexError.
    if not outputs:
        raise RuntimeError("RKNN inference returned no outputs")

    # Decode in model-input coordinates, then map back to the original image.
    detections = postprocess_yolov8(outputs, conf_threshold=conf_threshold)
    detections = scale_detections_to_original(detections, rgb_image.shape, input_size=input_size)

    # Any detection at or above the threshold means the product is defective.
    is_good = not detections

    if debug:
        _show_debug_result(rgb_image, _best_detection_per_class(detections))
    return is_good
# ------------------- Script entry point: read an image and run detection -------------------
def main(argv=None):
    """Command-line entry point: read an image, run detection, print the verdict.

    With no arguments it behaves exactly like the original hard-coded script:

    * image ``./test_image/detect/defect1.jpg``;
    * model ``xiantiao_detect.rknn``;
    * confidence threshold 0.5;
    * debug display on.

    Returns:
        The result of ``detect_quexian_from_rgb_array`` (True = no defect).

    Raises:
        RuntimeError: if the image cannot be read.
    """
    import argparse

    parser = argparse.ArgumentParser(description="RKNN xiantiao defect detection")
    parser.add_argument("image_path", nargs="?", default="./test_image/detect/defect1.jpg",
                        help="image file to inspect")
    parser.add_argument("--model", default="xiantiao_detect.rknn", help="RKNN model path")
    parser.add_argument("--conf", type=float, default=0.5, help="confidence threshold")
    parser.add_argument("--no-debug", action="store_true",
                        help="do not print/show the detections")
    args = parser.parse_args(argv)

    # 1. Read the image (cv2.imread returns BGR, or None on failure).
    bgr_image = cv2.imread(args.image_path)
    if bgr_image is None:
        raise RuntimeError(f"Failed to read image: {args.image_path}")
    # 2. Convert to RGB, which the inference function expects.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # 3. Run inference on the np.ndarray.
    is_good_product = detect_quexian_from_rgb_array(
        rgb_image=rgb_image,
        model_path=args.model,
        conf_threshold=args.conf,
        debug=not args.no_debug,
    )
    # 4. Report the result.
    if is_good_product:
        print("✅ 产品合格(无缺陷)")
    else:
        print("❌ 产品存在缺陷")
    return is_good_product


if __name__ == '__main__':
    main()