# File: xiantiao_CV/class_xiantiao_rknn/main_cls.py
# 100 lines, 3.2 KiB, Python
# Last modified: 2025-12-16 15:12:02 +08:00
import platform
import sys

import cv2
import numpy as np
from rknnlite.api import RKNNLite
# ------------------- Module-level state and configuration -------------------
# RKNNLite instance cached by init_rknn_model(); None until a model is loaded.
_global_rknn_instance = None

# Classifier output index -> label text
# (0: "no lines", 1: "has lines").
labels = {
    0: '无线条',
    1: '有线条',
}

# Region cropped from the input image before inference: (x, y, w, h) in pixels.
ROI = (3, 0, 694, 182)

# Device-tree node read by get_host() to identify the Rockchip SoC.
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
# ------------------- Host information -------------------
def get_host():
    """Identify the platform this script is running on.

    On Linux/aarch64 the Rockchip SoC is detected from the device-tree
    ``compatible`` node; an aarch64 board matching none of the known tags
    falls back to ``'RK3566_RK3568'``. On every other platform the result
    is ``"<system>-<machine>"`` as reported by :mod:`platform`
    (e.g. ``'Linux-x86_64'``).

    Returns:
        str: ``'RK3562'``, ``'RK3576'``, ``'RK3588'``, ``'RK3566_RK3568'``
        or ``"<system>-<machine>"``.

    Exits:
        Terminates the process with status -1 (via ``sys.exit``) if the
        device-tree node cannot be read on a Linux/aarch64 host.
    """
    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    try:
        with open(DEVICE_COMPATIBLE_NODE, encoding='utf-8', errors='replace') as f:
            device_compatible_str = f.read()
    except OSError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE),
              file=sys.stderr)
        # sys.exit rather than the site-module ``exit`` helper, which is meant
        # for the interactive shell and is absent under ``python -S`` or in
        # frozen/embedded interpreters.
        sys.exit(-1)

    # Checked in this order; the first matching SoC tag wins.
    for tag, host in (('rk3562', 'RK3562'),
                      ('rk3576', 'RK3576'),
                      ('rk3588', 'RK3588')):
        if tag in device_compatible_str:
            return host
    return 'RK3566_RK3568'
# ------------------- RKNN model initialisation (loaded only once) -------------------
def init_rknn_model(model_path, core_mask=None):
    """Load an RKNN model and initialise the NPU runtime, once per process.

    The first successful call creates an ``RKNNLite`` instance, loads
    ``model_path`` into it, initialises the runtime and caches the instance
    in the module-level ``_global_rknn_instance``. Later calls return the
    cached instance unchanged: ``model_path`` and ``core_mask`` are ignored
    once a model has been loaded.

    Args:
        model_path: Path to the ``.rknn`` model file.
        core_mask: NPU core mask passed to ``init_runtime``. ``None``
            (default) keeps the original behaviour of using
            ``RKNNLite.NPU_CORE_0``.

    Returns:
        The cached, initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status. The partially set-up instance is
            released and nothing is cached, so a later call can retry.
    """
    global _global_rknn_instance
    if _global_rknn_instance is not None:
        return _global_rknn_instance

    if core_mask is None:
        core_mask = RKNNLite.NPU_CORE_0

    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(model_path)
    if ret != 0:
        # Free the handle instead of leaking it on failure.
        rknn_lite.release()
        raise RuntimeError(f'Load model failed: {ret}')
    ret = rknn_lite.init_runtime(core_mask=core_mask)
    if ret != 0:
        # The model is loaded at this point; release it before bailing out.
        rknn_lite.release()
        raise RuntimeError(f'Init runtime failed: {ret}')

    _global_rknn_instance = rknn_lite
    print(f'[INFO] RKNN model loaded: {model_path}')
    return _global_rknn_instance
# ------------------- Image preprocessing + ROI crop -------------------
def preprocess(raw_image, target_size=(640, 640), roi=None):
    """Crop the ROI from an image, resize it and add a batch dimension.

    Args:
        raw_image: Image array indexed as ``raw_image[row, col, ...]``
            (as returned by ``cv2.imread``). Channel order is left to the
            caller; the ``__main__`` block passes RGB.
        target_size: ``(width, height)`` passed as ``dsize`` to ``cv2.resize``.
        roi: ``(x, y, w, h)`` crop rectangle in pixels. ``None`` (default)
            uses the module-level ``ROI``.

    Returns:
        ``numpy.ndarray``: the resized crop with a leading batch axis of
        size 1, e.g. ``(1, height, width, channels)`` for a colour image.

    Raises:
        ValueError: If the ROI does not overlap the image, leaving an empty
            crop. Without this check ``cv2.resize`` fails with an opaque
            assertion error.
    """
    x, y, w, h = ROI if roi is None else roi
    # Slicing clips silently at the image border; only a fully
    # out-of-range ROI produces an empty crop.
    roi_img = raw_image[y:y + h, x:x + w]
    if roi_img.size == 0:
        raise ValueError(
            f'ROI {(x, y, w, h)} is empty for image of shape {raw_image.shape}'
        )
    img_resized = cv2.resize(roi_img, target_size)
    return np.expand_dims(img_resized, 0)  # add batch dimension
# ------------------- Inference -------------------
def xiantiao_cls_inference_once(rknn, raw_image, target_size=(640, 640)):
    """Classify one image with an already-initialised RKNN model.

    The image is cropped to the module ROI and resized by ``preprocess``,
    then run through ``rknn.inference``. The class id is the arg-max of the
    first output, flattened.

    Args:
        rknn: Initialised ``RKNNLite`` instance, e.g. from ``init_rknn_model``.
        raw_image: Input image; the ``__main__`` block passes RGB.
        target_size: ``(width, height)`` forwarded to ``preprocess``.

    Returns:
        tuple[int, bool]: ``(class_id, class_id == 1)``. In the module
        ``labels`` mapping, 1 is '有线条' ("has lines") and 0 is '无线条'
        ("no lines").

    Raises:
        RuntimeError: If ``rknn.inference`` returns no outputs.
    """
    img = preprocess(raw_image, target_size)
    outputs = rknn.inference([img])
    # Fail clearly on an empty/failed inference instead of an opaque
    # TypeError/IndexError on ``outputs[0]``.
    if outputs is None or len(outputs) == 0:
        raise RuntimeError('RKNN inference returned no outputs')
    scores = outputs[0].reshape(-1)
    class_id = int(np.argmax(scores))
    return class_id, class_id == 1
# ------------------- Self-test -------------------
if __name__ == '__main__':
    image_path = "./test_image/class1/2.jpg"
    model_path = "xiantiao_rk3588.rknn"
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")
    # OpenCV loads BGR; the classifier is fed RGB.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # Initialise the model once; later calls reuse the same instance.
    rknn_model = init_rknn_model(model_path)
    try:
        class_id, bool_value = xiantiao_cls_inference_once(rknn_model, rgb_image)
        print(f"类别ID: {class_id}, 布尔值: {bool_value}")
    finally:
        # Free the NPU runtime resources held by the model before exiting.
        rknn_model.release()