165 lines
5.7 KiB
Python
165 lines
5.7 KiB
Python
import cv2
|
||
import numpy as np
|
||
import platform
|
||
from .labels import labels # 确保这个文件存在
|
||
from rknnlite.api import RKNNLite
|
||
|
||
|
||
# ------------------- 核心:全局变量存储RKNN模型实例(确保只加载一次) -------------------
|
||
# 初始化为None,首次调用时加载模型,后续直接复用
|
||
_global_rknn_instance = None
|
||
|
||
# device tree for RK356x/RK3576/RK3588
|
||
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||
|
||
def get_host():
|
||
# get platform and device type
|
||
system = platform.system()
|
||
machine = platform.machine()
|
||
os_machine = system + '-' + machine
|
||
if os_machine == 'Linux-aarch64':
|
||
try:
|
||
with open(DEVICE_COMPATIBLE_NODE) as f:
|
||
device_compatible_str = f.read()
|
||
if 'rk3562' in device_compatible_str:
|
||
host = 'RK3562'
|
||
elif 'rk3576' in device_compatible_str:
|
||
host = 'RK3576'
|
||
elif 'rk3588' in device_compatible_str:
|
||
host = 'RK3588'
|
||
else:
|
||
host = 'RK3566_RK3568'
|
||
except IOError:
|
||
print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
|
||
exit(-1)
|
||
else:
|
||
host = os_machine
|
||
return host
|
||
|
||
def get_top1_class_str(result):
|
||
"""
|
||
从推理结果中提取出得分最高的类别,并返回字符串
|
||
|
||
参数:
|
||
result (list): 模型推理输出结果(格式需与原函数一致,如 [np.ndarray])
|
||
返回:
|
||
str:得分最高类别的格式化字符串
|
||
若推理失败,返回错误提示字符串
|
||
"""
|
||
if result is None:
|
||
print("Inference failed: result is None")
|
||
return
|
||
|
||
# 解析推理输出(与原逻辑一致:展平输出为1维数组)
|
||
output = result[0].reshape(-1)
|
||
|
||
# 获取得分最高的类别索引(np.argmax 直接返回最大值索引,比排序更高效)
|
||
top1_index = np.argmax(output)
|
||
|
||
# 处理标签(确保索引在 labels 列表范围内,避免越界)
|
||
if 0 <= top1_index < len(labels):
|
||
top1_class_name = labels[top1_index]
|
||
else:
|
||
top1_class_name = "Unknown Class" # 应对索引异常的边界情况
|
||
|
||
# 5. 格式化返回字符串(包含索引、得分、类别名称,得分保留6位小数)
|
||
return top1_class_name
|
||
|
||
def preprocess(raw_image, target_size=(640, 640)):
|
||
"""
|
||
读取图像并执行预处理(BGR转RGB、调整尺寸、添加Batch维度)
|
||
|
||
参数:
|
||
image_path (str): 图像文件的完整路径(如 "C:/test.jpg" 或 "/home/user/test.jpg")
|
||
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
|
||
返回:
|
||
img (numpy.ndarray): 预处理后的图像
|
||
异常:
|
||
FileNotFoundError: 图像路径不存在或无法读取时抛出
|
||
ValueError: 图像读取成功但为空(如文件损坏)时抛出
|
||
"""
|
||
# img = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
|
||
# 调整尺寸
|
||
img = cv2.resize(raw_image, target_size)
|
||
img = np.expand_dims(img, 0) # 添加batch维度
|
||
|
||
return img
|
||
|
||
# ------------------- 新增:模型初始化函数(控制只加载一次) -------------------
|
||
def init_rknn_model(model_path):
|
||
"""
|
||
初始化RKNN模型(全局唯一实例):
|
||
- 首次调用:加载模型+初始化运行时,返回模型实例
|
||
- 后续调用:直接返回已加载的全局实例,避免重复加载
|
||
"""
|
||
global _global_rknn_instance # 声明使用全局变量
|
||
|
||
# 若模型未加载过,执行加载逻辑
|
||
if _global_rknn_instance is None:
|
||
# 1. 创建RKNN实例(关闭内置日志)
|
||
rknn_lite = RKNNLite(verbose=False)
|
||
|
||
# 2. 加载RKNN模型
|
||
ret = rknn_lite.load_rknn(model_path)
|
||
if ret != 0:
|
||
print(f'[ERROR] Load CLS_RKNN model failed (code: {ret})')
|
||
exit(ret)
|
||
|
||
# 3. 初始化运行时(绑定NPU核心0)
|
||
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||
if ret != 0:
|
||
print(f'[ERROR] Init CLS_RKNN runtime failed (code: {ret})')
|
||
exit(ret)
|
||
|
||
# 4. 将加载好的实例赋值给全局变量
|
||
_global_rknn_instance = rknn_lite
|
||
print(f'[INFO] CLS_RKNN model loaded successfully (path: {model_path})')
|
||
|
||
return _global_rknn_instance
|
||
|
||
def yolov11_cls_inference(model_path, raw_image, target_size=(640, 640)):
|
||
"""
|
||
根据平台进行推理,并返回最终的分类结果
|
||
|
||
参数:
|
||
model_path (str): RKNN模型文件路径
|
||
image_path (str): 图像文件的完整路径(如 "C:/test.jpg" 或 "/home/user/test.jpg")
|
||
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
|
||
"""
|
||
rknn_model = model_path
|
||
|
||
img = preprocess(raw_image, target_size)
|
||
|
||
rknn = init_rknn_model(rknn_model)
|
||
if rknn is None:
|
||
return None, img
|
||
outputs = rknn.inference([img])
|
||
|
||
# Show the classification results
|
||
class_name = get_top1_class_str(outputs)
|
||
|
||
# rknn_lite.release()
|
||
|
||
return class_name
|
||
|
||
if __name__ == '__main__':
|
||
|
||
# 调用yolov11_cls_inference函数(target_size使用默认值640x640,也可显式传参如(112,112))
|
||
image_path = "/userdata/reenrr/inference_with_lite/cover_ready.jpg"
|
||
bgr_image = cv2.imread(image_path)
|
||
if bgr_image is None:
|
||
print(f"Failed to read image from {image_path}")
|
||
exit(-1)
|
||
|
||
rgb_frame = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
|
||
print(f"Read image from {image_path}, shape: {rgb_frame.shape}")
|
||
|
||
result = yolov11_cls_inference(
|
||
model_path="/userdata/PyQt_main_test/app/view/yolo/yolov11_cls.rknn",
|
||
raw_image=rgb_frame,
|
||
target_size=(640, 640)
|
||
)
|
||
# 打印最终结果
|
||
print(f"\n最终分类结果:{result}")
|
||
|