Files
zjsh_video_collection/cls_inference/cls_inference.py
2025-09-26 20:41:44 +08:00

167 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import platform
from .labels import labels # 确保这个文件存在
from rknnlite.api import RKNNLite
# ------------------- 核心全局变量存储RKNN模型实例确保只加载一次 -------------------
# 初始化为None首次调用时加载模型后续直接复用
_global_rknn_instance = None
# device tree for RK356x/RK3576/RK3588
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
def get_host():
# get platform and device type
system = platform.system()
machine = platform.machine()
os_machine = system + '-' + machine
if os_machine == 'Linux-aarch64':
try:
with open(DEVICE_COMPATIBLE_NODE) as f:
device_compatible_str = f.read()
if 'rk3562' in device_compatible_str:
host = 'RK3562'
elif 'rk3576' in device_compatible_str:
host = 'RK3576'
elif 'rk3588' in device_compatible_str:
host = 'RK3588'
else:
host = 'RK3566_RK3568'
except IOError:
print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
exit(-1)
else:
host = os_machine
return host
def get_top1_class_str(result):
"""
从推理结果中提取出得分最高的类别,并返回字符串
参数:
result (list): 模型推理输出结果(格式需与原函数一致,如 [np.ndarray]
返回:
str:得分最高类别的格式化字符串
若推理失败,返回错误提示字符串
"""
if result is None:
print("Inference failed: result is None")
return
# 解析推理输出与原逻辑一致展平输出为1维数组
output = result[0].reshape(-1)
# 获取得分最高的类别索引np.argmax 直接返回最大值索引,比排序更高效)
top1_index = np.argmax(output)
# 处理标签(确保索引在 labels 列表范围内,避免越界)
if 0 <= top1_index < len(labels):
top1_class_name = labels[top1_index]
else:
top1_class_name = "Unknown Class" # 应对索引异常的边界情况
# 5. 格式化返回字符串包含索引、得分、类别名称得分保留6位小数
return top1_class_name
def preprocess(raw_image, target_size=(640, 640)):
"""
读取图像并执行预处理BGR转RGB、调整尺寸、添加Batch维度
参数:
image_path (str): 图像文件的完整路径(如 "C:/test.jpg""/home/user/test.jpg"
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
返回:
img (numpy.ndarray): 预处理后的图像
异常:
FileNotFoundError: 图像路径不存在或无法读取时抛出
ValueError: 图像读取成功但为空(如文件损坏)时抛出
"""
# img = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
# 调整尺寸
img = cv2.resize(raw_image, target_size)
img = np.expand_dims(img, 0) # 添加batch维度
return img
# ------------------- 新增:模型初始化函数(控制只加载一次) -------------------
def init_rknn_model(model_path):
"""
初始化RKNN模型全局唯一实例
- 首次调用:加载模型+初始化运行时,返回模型实例
- 后续调用:直接返回已加载的全局实例,避免重复加载
"""
global _global_rknn_instance # 声明使用全局变量
# 若模型未加载过,执行加载逻辑
if _global_rknn_instance is None:
# 1. 创建RKNN实例关闭内置日志
rknn_lite = RKNNLite(verbose=False)
# 2. 加载RKNN模型
ret = rknn_lite.load_rknn(model_path)
if ret != 0:
print(f'[ERROR] Load CLS_RKNN model failed (code: {ret})')
exit(ret)
# 3. 初始化运行时绑定NPU核心0
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
if ret != 0:
print(f'[ERROR] Init CLS_RKNN runtime failed (code: {ret})')
exit(ret)
# 4. 将加载好的实例赋值给全局变量
_global_rknn_instance = rknn_lite
print(f'[INFO] CLS_RKNN model loaded successfully (path: {model_path})')
return _global_rknn_instance
def yolov11_cls_inference(model_path, raw_image, target_size=(640, 640)):
"""
根据平台进行推理,并返回最终的分类结果
参数:
model_path (str): RKNN模型文件路径
image_path (str): 图像文件的完整路径(如 "C:/test.jpg""/home/user/test.jpg"
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
"""
rknn_model = model_path
img = preprocess(raw_image, target_size)
# 只加载一次模型,避免重复加载
rknn = init_rknn_model(rknn_model)
if rknn is None:
return None, img
outputs = rknn.inference([img])
# Show the classification results
class_name = get_top1_class_str(outputs)
# rknn_lite.release()
return class_name
if __name__ == '__main__':
# 调用yolov11_cls_inference函数target_size使用默认值640x640也可显式传参如(112,112)
image_path = "/userdata/reenrr/inference_with_lite/cover_ready.jpg"
bgr_image = cv2.imread(image_path)
if bgr_image is None:
print(f"Failed to read image from {image_path}")
exit(-1)
rgb_frame = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
print(f"Read image from {image_path}, shape: {rgb_frame.shape}")
result = yolov11_cls_inference(
model_path="/userdata/PyQt_main_test/app/view/yolo/yolov11_cls.rknn",
raw_image=rgb_frame,
target_size=(640, 640)
)
# 打印最终结果
print(f"\n最终分类结果:{result}")