Files
xiantiao_CV/class_xiantiao_rknn/main_quexian_cls.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

94 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import platform
from rknnlite.api import RKNNLite
# ------------------- Global variables -------------------
# Process-wide RKNNLite instance cached by init_rknn_model(); stays None
# until the first successful model load.
_global_rknn_instance: "RKNNLite | None" = None
# Class id -> display label for the classifier output
# (0: '无缺陷' = no defect, 1: '有缺陷' = defect present).
labels: "dict[int, str]" = {
    0: '无缺陷',
    1: '有缺陷',
}
# Device-tree node that get_host() reads to identify the Rockchip SoC.
DEVICE_COMPATIBLE_NODE: str = '/proc/device-tree/compatible'
# ------------------- Host information -------------------
def get_host():
    """Identify the platform this script is running on.

    On Linux/aarch64, the Rockchip SoC is detected by searching the
    device-tree ``compatible`` node (``DEVICE_COMPATIBLE_NODE``) for a
    known chip name. The first match among rk3562, rk3576 and rk3588 wins.
    Anything else on that platform is reported as ``'RK3566_RK3568'``.
    On every other platform, the ``"<system>-<machine>"`` string built
    from :mod:`platform` is returned unchanged (e.g. ``'Linux-x86_64'``).

    Returns:
        str: the host identifier.

    Exits:
        Prints a message and terminates the process with status -1 when
        the device-tree node cannot be read on a Linux/aarch64 host.
    """
    # sys.exit instead of the bare ``exit`` builtin: ``exit`` is injected
    # by the ``site`` module and is missing under ``python -S`` and in
    # some frozen/embedded interpreters.
    import sys

    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    # Keep the try minimal: only the file access can raise IOError.
    try:
        with open(DEVICE_COMPATIBLE_NODE) as f:
            device_compatible_str = f.read()
    except IOError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
        sys.exit(-1)

    # Checked in order; the first chip name found in the node wins.
    known_chips = (
        ('rk3562', 'RK3562'),
        ('rk3576', 'RK3576'),
        ('rk3588', 'RK3588'),
    )
    for chip_name, host in known_chips:
        if chip_name in device_compatible_str:
            return host
    return 'RK3566_RK3568'
# ------------------- RKNN model initialization (loaded only once) -------------------
def init_rknn_model(model_path, core_mask=None):
    """Load an RKNN model and initialize the NPU runtime, at most once per process.

    The first successful call does three things: it creates an
    ``RKNNLite`` instance, loads ``model_path`` into it, and initializes
    the runtime. The instance is then cached in the module-level
    ``_global_rknn_instance``. Every later call returns that cached
    instance and ignores its arguments, including a different
    ``model_path``.

    Args:
        model_path: path to the ``.rknn`` model file.
        core_mask: NPU core mask passed to ``init_runtime``. Defaults to
            ``RKNNLite.NPU_CORE_0`` (the previously hard-coded value).

    Returns:
        The cached, initialized ``RKNNLite`` instance.

    Raises:
        RuntimeError: if ``load_rknn`` or ``init_runtime`` returns a
            non-zero status. In that case the freshly created instance
            is released and nothing is cached, so a later call retries.
    """
    global _global_rknn_instance
    if _global_rknn_instance is not None:
        return _global_rknn_instance

    if core_mask is None:
        core_mask = RKNNLite.NPU_CORE_0

    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(model_path)
    if ret != 0:
        # Free the native context instead of leaking it on failure.
        rknn_lite.release()
        raise RuntimeError(f'Load model failed: {ret}')
    ret = rknn_lite.init_runtime(core_mask=core_mask)
    if ret != 0:
        rknn_lite.release()
        raise RuntimeError(f'Init runtime failed: {ret}')

    _global_rknn_instance = rknn_lite
    print(f'[INFO] RKNN model loaded: {model_path}')
    return _global_rknn_instance
# ------------------- Image preprocessing (no ROI; resize the whole image) -------------------
def preprocess(raw_image, target_size=(640, 640)):
    """Resize the whole image to the model input size and add a batch axis.

    Args:
        raw_image: image array to feed the model. In this file's test
            entry point it is an RGB image converted from ``cv2.imread``.
        target_size: ``dsize`` passed to ``cv2.resize``. Note that OpenCV
            interprets it as ``(width, height)``.

    Returns:
        The resized image with a new leading axis of length 1,
        i.e. ``(H, W, C) -> (1, H, W, C)`` for a colour image.
    """
    resized = cv2.resize(raw_image, target_size)
    # Same operation np.expand_dims(resized, 0) performs internally.
    return resized.reshape((1, *resized.shape))
# ------------------- Inference -------------------
def quexian_cls_inference_once(rknn, raw_image, target_size=(640, 640)):
    """Run one defect-classification inference with an already-loaded model.

    Args:
        rknn: initialized RKNNLite instance (e.g. from ``init_rknn_model``).
        raw_image: input image; it is resized and batched by ``preprocess``.
        target_size: resize target forwarded to ``preprocess``.

    Returns:
        tuple[int, bool]: ``(class_id, bool_value)``.
        ``class_id`` is the arg-max over the flattened first model output.
        ``bool_value`` is True iff ``class_id == 1``, which ``labels`` maps
        to '有缺陷' (defect present).

    Raises:
        RuntimeError: if the runtime returns no outputs.
    """
    img = preprocess(raw_image, target_size)
    outputs = rknn.inference([img])
    # Fail with a clear message instead of an opaque TypeError on None[0].
    if not outputs:
        raise RuntimeError('Inference failed: no outputs returned')
    scores = np.asarray(outputs[0]).reshape(-1)
    class_id = int(np.argmax(scores))
    bool_value = class_id == 1  # 1 means "defect present" (see labels)
    return class_id, bool_value
# ------------------- Test entry point -------------------
if __name__ == '__main__':
    image_path = "./test_image/class1/2.jpg"
    model_path = "xiantiao_rk3588.rknn"

    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")
    # cv2.imread decodes to BGR; convert to RGB before inference.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # Initialize the model only once.
    rknn_model = init_rknn_model(model_path)
    try:
        # Inference
        class_id, bool_value = quexian_cls_inference_once(rknn_model, rgb_image)
        print(f"类别ID: {class_id}, 布尔值: {bool_value}")
        # .get(): a model with more classes than `labels` must not KeyError here.
        print(f"预测结果: {labels.get(class_id, class_id)}")
    finally:
        # Release the RKNNLite runtime before the process exits.
        rknn_model.release()