first commit
This commit is contained in:
216
class_xiantiao_rknn/main_detect.py
Normal file
216
class_xiantiao_rknn/main_detect.py
Normal file
@ -0,0 +1,216 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import platform
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
# ------------------- 全局变量 -------------------
|
||||
# Cached RKNNLite runtime; populated on the first init_rknn_detector() call.
_global_rknn_detector = None

# Class id -> defect name. Keep this in the same order as the training classes.
labels = {
    0: 'hole',
    1: 'crack',
}

# Device-tree node that get_host() reads to tell Rockchip SoCs apart.
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||||
|
||||
|
||||
# ------------------- 主机信息 -------------------
|
||||
def get_host():
    """Identify the platform this script is running on.

    On 64-bit ARM Linux the device-tree ``compatible`` node is read to tell
    the supported Rockchip SoCs apart. On every other platform the generic
    ``"<system>-<machine>"`` string is returned.

    Returns:
        str: ``'RK3562'``, ``'RK3576'`` or ``'RK3588'`` when the matching
        marker appears in the device-tree node, ``'RK3566_RK3568'`` for any
        other Linux-aarch64 board, otherwise
        ``platform.system() + '-' + platform.machine()``.

    Raises:
        SystemExit: With status ``-1`` when the device-tree node cannot be
            read on a Linux-aarch64 host.
    """
    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    # Only the file read can raise IOError, so keep the try body minimal.
    try:
        with open(DEVICE_COMPATIBLE_NODE) as f:
            device_compatible_str = f.read()
    except IOError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
        # The bare ``exit()`` helper is injected by the ``site`` module for
        # interactive use and may be missing (``python -S``, frozen builds);
        # raising SystemExit gives the same exit status without relying on it.
        raise SystemExit(-1)

    # Checked in order; the first marker found in the node wins.
    known_socs = (
        ('rk3562', 'RK3562'),
        ('rk3576', 'RK3576'),
        ('rk3588', 'RK3588'),
    )
    for marker, host in known_socs:
        if marker in device_compatible_str:
            return host
    return 'RK3566_RK3568'
|
||||
|
||||
|
||||
# ------------------- RKNN 检测模型初始化(单例) -------------------
|
||||
def init_rknn_detector(model_path):
    """Return the shared RKNNLite detector, creating it on first use.

    The runtime is created once and cached in ``_global_rknn_detector``.
    Later calls return the cached object and ignore ``model_path``.

    Args:
        model_path: Path of the RKNN model file loaded on the first call.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If ``load_rknn`` or ``init_runtime`` returns a
            non-zero status.
    """
    global _global_rknn_detector
    if _global_rknn_detector is not None:
        return _global_rknn_detector

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f'Load RKNN detection model failed: {ret}')

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f'Init RKNN runtime failed: {ret}')
    except Exception:
        # Release the half-initialised handle so a failed attempt does not
        # leak runtime resources, then let the original error propagate.
        rknn.release()
        raise

    # Publish the instance only after it is fully initialised.
    _global_rknn_detector = rknn
    print(f'[INFO] RKNN detection model loaded: {model_path}')
    return _global_rknn_detector
|
||||
|
||||
|
||||
# ------------------- 预处理:接收 RGB np.ndarray,返回模型输入 -------------------
|
||||
def preprocess_rgb_image(rgb_image, target_size=(640, 640)):
    """Resize an RGB image and lay it out as a one-image NCHW batch.

    Args:
        rgb_image: Image array; the caller is expected to pass an
            (H, W, C) uint8 array in RGB channel order.
        target_size: Size handed unchanged to ``cv2.resize``.

    Returns:
        np.ndarray: The resized image with shape (1, C, H, W). Pixel values
        are not rescaled or normalised here.
    """
    # NOTE(review): the array is returned channel-first; confirm that the
    # inference call is told about this layout (RKNN data_format) — TODO confirm.
    hwc = cv2.resize(rgb_image, target_size)
    chw = hwc.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    return chw[np.newaxis, ...]  # prepend the batch axis -> (1, C, H, W)
|
||||
|
||||
|
||||
# ------------------- 后处理 -------------------
|
||||
def postprocess_yolov8(outputs, conf_threshold=0.5, input_size=(640, 640)):
    """Decode the first model output into a list of detections.

    The first output is treated as a ``(4 + num_classes, N)`` array (an
    optional leading batch axis of size 1 is squeezed away): rows 0-3 hold
    ``cx, cy, w, h`` and the remaining rows hold per-class scores. No
    non-maximum suppression is applied.

    Args:
        outputs: Sequence of model outputs; only ``outputs[0]`` is used.
        conf_threshold: Minimum best-class score for a candidate to be kept.
        input_size: Unused; kept so existing callers stay compatible.

    Returns:
        list[dict]: One dict per kept candidate with keys ``'bbox'``
        (``[x1, y1, x2, y2]`` in the model's coordinate system),
        ``'confidence'`` (float) and ``'class_id'`` (int).

    Raises:
        ValueError: If ``outputs`` is ``None`` or empty.
    """
    # Fail with a clear message instead of an opaque TypeError/IndexError
    # when the inference call produced nothing.
    if outputs is None or len(outputs) == 0:
        raise ValueError('RKNN inference returned no outputs')

    pred = np.asarray(outputs[0])
    if pred.ndim == 3:
        pred = np.squeeze(pred, axis=0)
    pred = pred.T  # one row per candidate: [cx, cy, w, h, score_0, ...]

    scores = pred[:, 4:]
    class_ids = np.argmax(scores, axis=1)
    confidences = np.max(scores, axis=1)

    keep = confidences >= conf_threshold

    # Convert centre/size boxes to corner form for the surviving candidates.
    cx, cy, w, h = pred[keep, :4].T
    half_w = w / 2
    half_h = h / 2
    boxes_xyxy = np.stack(
        (cx - half_w, cy - half_h, cx + half_w, cy + half_h), axis=1
    )

    return [
        {'bbox': bbox, 'confidence': float(conf), 'class_id': int(cls_id)}
        for bbox, conf, cls_id in zip(
            boxes_xyxy.tolist(), confidences[keep], class_ids[keep]
        )
    ]
|
||||
|
||||
|
||||
# ------------------- 坐标映射回原始图像 -------------------
|
||||
def scale_detections_to_original(detections, original_shape, input_size=(640, 640)):
    """Map boxes from model-input coordinates to original-image coordinates.

    Each axis is scaled independently (plain resize, no letterbox padding
    is accounted for).

    Args:
        detections: Iterable of dicts with ``'bbox'`` (``[x1, y1, x2, y2]``),
            ``'confidence'`` and ``'class_id'`` keys.
        original_shape: Shape of the original image; the first two entries
            are read as ``(height, width)``.
        input_size: ``(width, height)`` of the model input.

    Returns:
        A new list of dicts with rescaled ``'bbox'`` values. The input dicts
        are not modified. A falsy ``detections`` value is returned unchanged.
    """
    if not detections:
        return detections

    orig_h, orig_w = original_shape[:2]
    input_w, input_h = input_size

    # The ratios are the same for every box, so compute them once.
    scale_x = orig_w / input_w
    scale_y = orig_h / input_h

    def _rescale(bbox):
        x1, y1, x2, y2 = bbox
        return [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]

    return [
        {
            'bbox': _rescale(det['bbox']),
            'confidence': det['confidence'],
            'class_id': det['class_id'],
        }
        for det in detections
    ]
|
||||
|
||||
|
||||
# ------------------- 主推理函数:接收 RGB np.ndarray -------------------
|
||||
def _show_debug_result(rgb_image, top_detections):
    """Draw the detections on a BGR copy of the image, print them, and show a window."""
    # OpenCV drawing and display expect BGR channel order.
    bgr_vis = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    for det in top_detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        label = f"{labels.get(det['class_id'], 'unknown')} {det['confidence']:.2f}"
        cv2.rectangle(bgr_vis, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Keep the label inside the image when the box touches the top edge.
        text_y = max(y1 - 10, 15)
        cv2.putText(bgr_vis, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    print("\n===== RKNN Detection Result (DEBUG) =====")
    print(f"Defect classes detected: {len(top_detections)}")
    for det in top_detections:
        cls_name = labels.get(det['class_id'], 'unknown')
        print(f" - {cls_name}: conf={det['confidence']:.3f}, bbox={det['bbox']}")

    # Blocks until a key is pressed in the preview window.
    cv2.imshow('Detection', bgr_vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def detect_quexian_from_rgb_array(
    rgb_image: np.ndarray,
    model_path: str = "xiantiao_detect.rknn",
    conf_threshold: float = 0.5,
    debug: bool = False
) -> bool:
    """Run defect detection on one RGB image and report whether it is defect-free.

    Pipeline: load (or reuse) the RKNN model, resize the image to 640x640,
    run inference, decode the output, map boxes back to the original image
    and keep only the highest-confidence box of each class.

    Args:
        rgb_image: Image array, expected as (H, W, 3) uint8 in RGB order.
        model_path: Path of the RKNN model file.
        conf_threshold: Confidence threshold for keeping a detection.
        debug: When True, print the detections and show them in a window.

    Returns:
        bool: True for a good product (no defect detected), False when at
        least one defect was detected.

    Raises:
        ValueError: If ``rgb_image`` is ``None`` or has no elements.
        RuntimeError: If the model cannot be initialised or inference
            returns no outputs.
    """
    if rgb_image is None or rgb_image.size == 0:
        raise ValueError("输入图像为空")

    # One definition so preprocessing and box rescaling cannot drift apart.
    input_size = (640, 640)

    # Model initialisation is a singleton; repeated calls reuse the runtime.
    rknn = init_rknn_detector(model_path)

    img_input = preprocess_rgb_image(rgb_image, target_size=input_size)

    outputs = rknn.inference(inputs=[img_input])
    if outputs is None:
        # Without this check a failed inference surfaces later as an
        # unrelated TypeError inside post-processing.
        raise RuntimeError('RKNN inference failed: no outputs returned')

    # Decode in model-input coordinates, then map to the original image.
    detections = postprocess_yolov8(outputs, conf_threshold=conf_threshold)
    detections = scale_detections_to_original(detections, rgb_image.shape, input_size=input_size)

    # Keep only the highest-confidence box for each class.
    best_per_class = {}
    for det in detections:
        current = best_per_class.get(det['class_id'])
        if current is None or det['confidence'] > current['confidence']:
            best_per_class[det['class_id']] = det
    top_detections = list(best_per_class.values())

    if debug:
        _show_debug_result(rgb_image, top_detections)

    return not top_detections
|
||||
|
||||
|
||||
# ------------------- 主函数:读图并调用 -------------------
|
||||
if __name__ == '__main__':
    # Demo entry point: read one test image from disk and report the verdict.
    image_path = "./test_image/detect/defect1.jpg"
    model_path = "xiantiao_detect.rknn"

    # cv2.imread yields None rather than raising when the file is unreadable.
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")

    # The detector expects RGB, while OpenCV loads images as BGR.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    is_good_product = detect_quexian_from_rgb_array(
        rgb_image=rgb_image,
        model_path=model_path,
        conf_threshold=0.5,
        debug=True,
    )

    verdict = "✅ 产品合格（无缺陷）" if is_good_product else "❌ 产品存在缺陷"
    print(verdict)
|
||||
Reference in New Issue
Block a user