94 lines
3.2 KiB
Python
94 lines
3.2 KiB
Python
import cv2
|
||
import numpy as np
|
||
import platform
|
||
from rknnlite.api import RKNNLite
|
||
|
||
# ------------------- 全局变量 -------------------
|
||
# Process-wide cache for the loaded RKNNLite runtime; filled in lazily by
# init_rknn_model() so the model is only loaded once.
_global_rknn_instance = None

# Human-readable names for the classifier's output indices
# (0 -> "no defect", 1 -> "defect").
labels = {
    0: '无缺陷',
    1: '有缺陷',
}

# Device-tree node that is read to identify the SoC on ARM Linux boards.
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||
|
||
|
||
# ------------------- 主机信息 -------------------
|
||
def get_host():
    """Identify the platform this process is running on.

    On 64-bit ARM Linux the device-tree ``compatible`` node is read and
    searched for a Rockchip SoC marker. On every other platform the generic
    ``"<system>-<machine>"`` string is returned.

    Returns:
        str: ``'RK3562'``, ``'RK3576'`` or ``'RK3588'`` when the matching
        marker is found in the node, ``'RK3566_RK3568'`` for any other
        Linux-aarch64 device, otherwise a string such as ``'Linux-x86_64'``.

    Raises:
        SystemExit: with exit code -1 when the device-tree node cannot be
            read on a Linux-aarch64 host.
    """
    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    # Keep the try body limited to the read that can actually fail.
    try:
        with open(DEVICE_COMPATIBLE_NODE) as f:
            device_compatible_str = f.read()
    except OSError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
        # Raise SystemExit directly: the bare exit() helper is injected by
        # the `site` module for interactive use and is missing under
        # `python -S` or in frozen/embedded interpreters.
        raise SystemExit(-1)

    # Checked in the same order as before; the first marker found wins.
    known_socs = (
        ('rk3562', 'RK3562'),
        ('rk3576', 'RK3576'),
        ('rk3588', 'RK3588'),
    )
    for marker, host in known_socs:
        if marker in device_compatible_str:
            return host

    # No explicit marker matched: fall back to the RK3566/RK3568 family.
    return 'RK3566_RK3568'
|
||
|
||
|
||
# ------------------- RKNN 模型初始化(只加载一次) -------------------
|
||
def init_rknn_model(model_path):
    """Load the RKNN model once and return the shared runtime instance.

    The first successful call creates an ``RKNNLite`` object, loads
    ``model_path`` into it, initialises the runtime on NPU core 0 and caches
    the object in ``_global_rknn_instance``. Every later call returns that
    cached object and ignores ``model_path``.

    Args:
        model_path: Path of the ``.rknn`` model file to load.

    Returns:
        The initialised (and cached) ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status code.
    """
    global _global_rknn_instance
    if _global_rknn_instance is not None:
        return _global_rknn_instance

    rknn_lite = RKNNLite(verbose=False)
    try:
        ret = rknn_lite.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f'Load model failed: {ret}')

        ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f'Init runtime failed: {ret}')
    except Exception:
        # Do not abandon a half-initialised runtime: free whatever it
        # acquired before propagating the error, so retries do not leak.
        rknn_lite.release()
        raise

    # Only publish the instance once it is fully usable.
    _global_rknn_instance = rknn_lite
    print(f'[INFO] RKNN model loaded: {model_path}')
    return _global_rknn_instance
|
||
|
||
|
||
# ------------------- 图像预处理(无 ROI,直接 resize 整图) -------------------
|
||
def preprocess(raw_image, target_size=(640, 640)):
    """Resize the whole image to the model input size and add a batch axis.

    No region of interest is cropped; the full frame is scaled.

    Args:
        raw_image: Image array accepted by ``cv2.resize``.
        target_size: Output size handed to ``cv2.resize`` as its ``dsize``
            argument (OpenCV interprets this as ``(width, height)``).

    Returns:
        The resized image with a new leading axis of length 1, e.g.
        ``(H, W, C) -> (1, H, W, C)`` for a colour image.
    """
    resized = cv2.resize(raw_image, target_size)
    # Prepend the batch dimension.
    return resized[np.newaxis, ...]
|
||
|
||
|
||
# ------------------- 推理函数 -------------------
|
||
def quexian_cls_inference_once(rknn, raw_image, target_size=(640, 640)):
    """Run one classification pass with an already-loaded RKNN instance.

    Args:
        rknn: Object exposing ``inference(list_of_inputs)``, e.g. the
            instance returned by ``init_rknn_model``.
        raw_image: Image to classify; it is resized by ``preprocess``.
        target_size: Model input size forwarded to ``preprocess``.

    Returns:
        tuple: ``(class_id, is_positive)`` where ``class_id`` is the index of
        the largest value in the first model output and ``is_positive`` is
        ``True`` exactly when ``class_id == 1``.

    NOTE(review): the original inline comment described class 1 as "has
    lines", while the module-level ``labels`` table maps 1 to "defect";
    confirm which wording matches the trained model.
    """
    batch = preprocess(raw_image, target_size)
    # Only the first output tensor is used; flatten it to a 1-D score vector.
    scores = rknn.inference([batch])[0].reshape(-1)
    class_id = int(np.argmax(scores))
    return class_id, class_id == 1
|
||
|
||
|
||
# ------------------- 测试 -------------------
|
||
if __name__ == '__main__':
    # Smoke test: classify a single image from disk and print the result.
    model_path = "xiantiao_rk3588.rknn"
    image_path = "./test_image/class1/2.jpg"

    # cv2.imread signals failure by returning None rather than raising.
    frame_bgr = cv2.imread(image_path)
    if frame_bgr is None:
        raise RuntimeError(f"Failed to read image: {image_path}")
    # OpenCV decodes to BGR; convert to RGB before inference.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # The model is initialised only once (cached inside init_rknn_model).
    runtime = init_rknn_model(model_path)

    # Run inference and report the prediction.
    result = quexian_cls_inference_once(runtime, frame_rgb)
    class_id, bool_value = result
    print(f"类别ID: {class_id}, 布尔值: {bool_value}")
    print(f"预测结果: {labels[class_id]}")