分类部署例程
This commit is contained in:
3
.idea/.gitignore
generated
vendored
Normal file
3
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
8
.idea/inference_with_lite.iml
generated
Normal file
8
.idea/inference_with_lite.iml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="image-classification-flower" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
18
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
18
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="5">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="scipy" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="snap7" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="jsonchema" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="werkzeung" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="image-classification-flower" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="image-classification-flower" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/inference_with_lite.iml" filepath="$PROJECT_DIR$/.idea/inference_with_lite.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
164
cls_inference.py
Normal file
164
cls_inference.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import platform
|
||||||
|
from .labels import labels # 确保这个文件存在
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------- 核心:全局变量存储RKNN模型实例(确保只加载一次) -------------------
|
||||||
|
# 初始化为None,首次调用时加载模型,后续直接复用
|
||||||
|
_global_rknn_instance = None
|
||||||
|
|
||||||
|
# device tree for RK356x/RK3576/RK3588
|
||||||
|
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||||||
|
|
||||||
|
def get_host():
|
||||||
|
# get platform and device type
|
||||||
|
system = platform.system()
|
||||||
|
machine = platform.machine()
|
||||||
|
os_machine = system + '-' + machine
|
||||||
|
if os_machine == 'Linux-aarch64':
|
||||||
|
try:
|
||||||
|
with open(DEVICE_COMPATIBLE_NODE) as f:
|
||||||
|
device_compatible_str = f.read()
|
||||||
|
if 'rk3562' in device_compatible_str:
|
||||||
|
host = 'RK3562'
|
||||||
|
elif 'rk3576' in device_compatible_str:
|
||||||
|
host = 'RK3576'
|
||||||
|
elif 'rk3588' in device_compatible_str:
|
||||||
|
host = 'RK3588'
|
||||||
|
else:
|
||||||
|
host = 'RK3566_RK3568'
|
||||||
|
except IOError:
|
||||||
|
print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
|
||||||
|
exit(-1)
|
||||||
|
else:
|
||||||
|
host = os_machine
|
||||||
|
return host
|
||||||
|
|
||||||
|
def get_top1_class_str(result):
|
||||||
|
"""
|
||||||
|
从推理结果中提取出得分最高的类别,并返回字符串
|
||||||
|
|
||||||
|
参数:
|
||||||
|
result (list): 模型推理输出结果(格式需与原函数一致,如 [np.ndarray])
|
||||||
|
返回:
|
||||||
|
str:得分最高类别的格式化字符串
|
||||||
|
若推理失败,返回错误提示字符串
|
||||||
|
"""
|
||||||
|
if result is None:
|
||||||
|
print("Inference failed: result is None")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析推理输出(与原逻辑一致:展平输出为1维数组)
|
||||||
|
output = result[0].reshape(-1)
|
||||||
|
|
||||||
|
# 获取得分最高的类别索引(np.argmax 直接返回最大值索引,比排序更高效)
|
||||||
|
top1_index = np.argmax(output)
|
||||||
|
|
||||||
|
# 处理标签(确保索引在 labels 列表范围内,避免越界)
|
||||||
|
if 0 <= top1_index < len(labels):
|
||||||
|
top1_class_name = labels[top1_index]
|
||||||
|
else:
|
||||||
|
top1_class_name = "Unknown Class" # 应对索引异常的边界情况
|
||||||
|
|
||||||
|
# 5. 格式化返回字符串(包含索引、得分、类别名称,得分保留6位小数)
|
||||||
|
return top1_class_name
|
||||||
|
|
||||||
|
def preprocess(raw_image, target_size=(640, 640)):
|
||||||
|
"""
|
||||||
|
读取图像并执行预处理(BGR转RGB、调整尺寸、添加Batch维度)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
image_path (str): 图像文件的完整路径(如 "C:/test.jpg" 或 "/home/user/test.jpg")
|
||||||
|
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
|
||||||
|
返回:
|
||||||
|
img (numpy.ndarray): 预处理后的图像
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 图像路径不存在或无法读取时抛出
|
||||||
|
ValueError: 图像读取成功但为空(如文件损坏)时抛出
|
||||||
|
"""
|
||||||
|
# img = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
|
||||||
|
# 调整尺寸
|
||||||
|
img = cv2.resize(raw_image, target_size)
|
||||||
|
img = np.expand_dims(img, 0) # 添加batch维度
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
# ------------------- 新增:模型初始化函数(控制只加载一次) -------------------
|
||||||
|
def init_rknn_model(model_path):
|
||||||
|
"""
|
||||||
|
初始化RKNN模型(全局唯一实例):
|
||||||
|
- 首次调用:加载模型+初始化运行时,返回模型实例
|
||||||
|
- 后续调用:直接返回已加载的全局实例,避免重复加载
|
||||||
|
"""
|
||||||
|
global _global_rknn_instance # 声明使用全局变量
|
||||||
|
|
||||||
|
# 若模型未加载过,执行加载逻辑
|
||||||
|
if _global_rknn_instance is None:
|
||||||
|
# 1. 创建RKNN实例(关闭内置日志)
|
||||||
|
rknn_lite = RKNNLite(verbose=False)
|
||||||
|
|
||||||
|
# 2. 加载RKNN模型
|
||||||
|
ret = rknn_lite.load_rknn(model_path)
|
||||||
|
if ret != 0:
|
||||||
|
print(f'[ERROR] Load CLS_RKNN model failed (code: {ret})')
|
||||||
|
exit(ret)
|
||||||
|
|
||||||
|
# 3. 初始化运行时(绑定NPU核心0)
|
||||||
|
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
if ret != 0:
|
||||||
|
print(f'[ERROR] Init CLS_RKNN runtime failed (code: {ret})')
|
||||||
|
exit(ret)
|
||||||
|
|
||||||
|
# 4. 将加载好的实例赋值给全局变量
|
||||||
|
_global_rknn_instance = rknn_lite
|
||||||
|
print(f'[INFO] CLS_RKNN model loaded successfully (path: {model_path})')
|
||||||
|
|
||||||
|
return _global_rknn_instance
|
||||||
|
|
||||||
|
def yolov11_cls_inference(model_path, raw_image, target_size=(640, 640)):
|
||||||
|
"""
|
||||||
|
根据平台进行推理,并返回最终的分类结果
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model_path (str): RKNN模型文件路径
|
||||||
|
image_path (str): 图像文件的完整路径(如 "C:/test.jpg" 或 "/home/user/test.jpg")
|
||||||
|
target_size (tuple): 预处理后图像的目标尺寸,格式为 (width, height),默认 (640, 640)
|
||||||
|
"""
|
||||||
|
rknn_model = model_path
|
||||||
|
|
||||||
|
img = preprocess(raw_image, target_size)
|
||||||
|
|
||||||
|
rknn = init_rknn_model(rknn_model)
|
||||||
|
if rknn is None:
|
||||||
|
return None, img
|
||||||
|
outputs = rknn.inference([img])
|
||||||
|
|
||||||
|
# Show the classification results
|
||||||
|
class_name = get_top1_class_str(outputs)
|
||||||
|
|
||||||
|
# rknn_lite.release()
|
||||||
|
|
||||||
|
return class_name
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# 调用yolov11_cls_inference函数(target_size使用默认值640x640,也可显式传参如(112,112))
|
||||||
|
image_path = "/userdata/reenrr/inference_with_lite/cover_ready.jpg"
|
||||||
|
bgr_image = cv2.imread(image_path)
|
||||||
|
if bgr_image is None:
|
||||||
|
print(f"Failed to read image from {image_path}")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
rgb_frame = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
|
||||||
|
print(f"Read image from {image_path}, shape: {rgb_frame.shape}")
|
||||||
|
|
||||||
|
result = yolov11_cls_inference(
|
||||||
|
model_path="/userdata/PyQt_main_test/app/view/yolo/yolov11_cls.rknn",
|
||||||
|
raw_image=rgb_frame,
|
||||||
|
target_size=(640, 640)
|
||||||
|
)
|
||||||
|
# 打印最终结果
|
||||||
|
print(f"\n最终分类结果:{result}")
|
||||||
|
|
||||||
BIN
cover_noready.jpg
Normal file
BIN
cover_noready.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 63 KiB |
BIN
cover_noready1.jpg
Normal file
BIN
cover_noready1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 598 KiB |
BIN
cover_ready.jpg
Normal file
BIN
cover_ready.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 60 KiB |
6
labels.py
Normal file
6
labels.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
# the labels come from synset.txt, download link: https://s3.amazonaws.com/onnx-model-zoo/synset.txt
|
||||||
|
|
||||||
|
labels = \
|
||||||
|
{0: 'cover_noready',
|
||||||
|
1: 'cover_ready'
|
||||||
|
}
|
||||||
BIN
mobilenetv2_224.rknn
Normal file
BIN
mobilenetv2_224.rknn
Normal file
Binary file not shown.
BIN
mobilenetv2_640.rknn
Normal file
BIN
mobilenetv2_640.rknn
Normal file
Binary file not shown.
124
mobilenetv2_inference.py
Normal file
124
mobilenetv2_inference.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import platform
|
||||||
|
from labels import labels # 确保这个文件存在
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
import time
|
||||||
|
|
||||||
|
model_path = '/userdata/reenrr/inference_with_lite/mobilenetv2_640.rknn'
|
||||||
|
image_path = '/userdata/reenrr/inference_with_lite/222.jpg'
|
||||||
|
target_size = (640, 640)
|
||||||
|
|
||||||
|
# device tree for RK356x/RK3576/RK3588
|
||||||
|
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||||||
|
|
||||||
|
def get_host():
|
||||||
|
# get platform and device type
|
||||||
|
system = platform.system()
|
||||||
|
machine = platform.machine()
|
||||||
|
os_machine = system + '-' + machine
|
||||||
|
if os_machine == 'Linux-aarch64':
|
||||||
|
try:
|
||||||
|
with open(DEVICE_COMPATIBLE_NODE) as f:
|
||||||
|
device_compatible_str = f.read()
|
||||||
|
if 'rk3562' in device_compatible_str:
|
||||||
|
host = 'RK3562'
|
||||||
|
elif 'rk3576' in device_compatible_str:
|
||||||
|
host = 'RK3576'
|
||||||
|
elif 'rk3588' in device_compatible_str:
|
||||||
|
host = 'RK3588'
|
||||||
|
else:
|
||||||
|
host = 'RK3566_RK3568'
|
||||||
|
except IOError:
|
||||||
|
print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
|
||||||
|
exit(-1)
|
||||||
|
else:
|
||||||
|
host = os_machine
|
||||||
|
return host
|
||||||
|
|
||||||
|
# 模型路径配置
|
||||||
|
RK3566_RK3568_RKNN_MODEL = 'resnet18_for_rk3566_rk3568.rknn'
|
||||||
|
RK3588_RKNN_MODEL = model_path
|
||||||
|
RK3562_RKNN_MODEL = 'resnet18_for_rk3562.rknn'
|
||||||
|
RK3576_RKNN_MODEL = 'resnet18_for_rk3576.rknn'
|
||||||
|
|
||||||
|
def show_top5(result):
|
||||||
|
if result is None:
|
||||||
|
print("Inference failed: result is None")
|
||||||
|
return
|
||||||
|
|
||||||
|
output = result[0].reshape(-1)
|
||||||
|
# Softmax
|
||||||
|
output = np.exp(output) / np.sum(np.exp(output))
|
||||||
|
# Get the indices of the top 5 largest values
|
||||||
|
output_sorted_indices = np.argsort(output)[::-1][:5]
|
||||||
|
top5_str = 'resnet18\n-----TOP 5-----\n'
|
||||||
|
for i, index in enumerate(output_sorted_indices):
|
||||||
|
value = output[index]
|
||||||
|
if value > 0:
|
||||||
|
topi = '[{:>3d}] score:{:.6f} class:"{}"\n'.format(index, value, labels[index])
|
||||||
|
else:
|
||||||
|
topi = '-1: 0.0\n'
|
||||||
|
top5_str += topi
|
||||||
|
print(top5_str)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Get device information
|
||||||
|
host_name = get_host()
|
||||||
|
if host_name == 'RK3566_RK3568':
|
||||||
|
rknn_model = RK3566_RK3568_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3562':
|
||||||
|
rknn_model = RK3562_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3576':
|
||||||
|
rknn_model = RK3576_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3588':
|
||||||
|
rknn_model = RK3588_RKNN_MODEL
|
||||||
|
else:
|
||||||
|
print("This demo cannot run on the current platform: {}".format(host_name))
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
rknn_lite = RKNNLite()
|
||||||
|
|
||||||
|
# Load RKNN model
|
||||||
|
print('--> Load RKNN model')
|
||||||
|
ret = rknn_lite.load_rknn(rknn_model)
|
||||||
|
if ret != 0:
|
||||||
|
print('Load RKNN model failed')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# 读取并预处理图像 - 这是关键修改部分
|
||||||
|
ori_img = cv2.imread(image_path)
|
||||||
|
img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
|
||||||
|
# 调整尺寸
|
||||||
|
img = cv2.resize(img, target_size)
|
||||||
|
img = np.expand_dims(img, 0) # 添加batch维度
|
||||||
|
|
||||||
|
# Init runtime environment
|
||||||
|
print('--> Init runtime environment')
|
||||||
|
if host_name in ['RK3576', 'RK3588']:
|
||||||
|
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
else:
|
||||||
|
ret = rknn_lite.init_runtime()
|
||||||
|
if ret != 0:
|
||||||
|
print('Init runtime environment failed')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
print("host_name:", host_name)
|
||||||
|
print("RKNNLite.NPU_CORE_0:", RKNNLite.NPU_CORE_0)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
print('--> Running model')
|
||||||
|
start_time = time.time()*1000 # 转为毫秒
|
||||||
|
outputs = rknn_lite.inference(inputs=[img])
|
||||||
|
end_time = time.time()*1000
|
||||||
|
print("outputs:", outputs)
|
||||||
|
print('Inference completed')
|
||||||
|
print("inference_time:", end_time-start_time,"ms")
|
||||||
|
|
||||||
|
# Show the classification results
|
||||||
|
show_top5(outputs)
|
||||||
|
|
||||||
|
rknn_lite.release()
|
||||||
19
output_shape.py
Normal file
19
output_shape.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
model_path = '/userdata/reenrr/inference_with_lite/mobilenetv2_640.rknn'
|
||||||
|
|
||||||
|
rknn_lite = RKNNLite()
|
||||||
|
rknn_lite.load_rknn(model_path)
|
||||||
|
rknn_lite.init_runtime()
|
||||||
|
|
||||||
|
# 通过实际推理获取输出维度
|
||||||
|
dummy_input = np.random.randn(1, 3, 640, 640).astype(np.float32) # 根据模型输入尺寸调整
|
||||||
|
outputs = rknn_lite.inference(inputs=[dummy_input])
|
||||||
|
|
||||||
|
print("\n输出维度信息:")
|
||||||
|
for i, out in enumerate(outputs):
|
||||||
|
print(f"Output {i} shape: {out.shape}") # 查看输出形状
|
||||||
|
|
||||||
|
rknn_lite.release()
|
||||||
|
|
||||||
6
readme.md
Normal file
6
readme.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
# 使用说明
|
||||||
|
方法一:修改cls_inference.py(替换图片和rknn模型路径和输入图片大小)和labels.py(分类标签名)
|
||||||
|
(只调用一次RKNN模型)
|
||||||
|
|
||||||
|
方法二:修改yolov11_cls_inference.py(替换图片和rknn模型路径和输入图片大小)和labels.py(分类标签名)
|
||||||
|
(每次调用RKNN模型)
|
||||||
BIN
yolov11_cls.rknn
Normal file
BIN
yolov11_cls.rknn
Normal file
Binary file not shown.
120
yolov11_cls_inference.py
Normal file
120
yolov11_cls_inference.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import platform
|
||||||
|
from labels import labels # 确保这个文件存在
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
|
||||||
|
model_path = '/userdata/reenrr/inference_with_lite/yolov11_cls.rknn'
|
||||||
|
image_path = '/userdata/reenrr/inference_with_lite/222.jpg'
|
||||||
|
target_size = (640, 640)
|
||||||
|
|
||||||
|
# device tree for RK356x/RK3576/RK3588
|
||||||
|
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'
|
||||||
|
|
||||||
|
def get_host():
|
||||||
|
# get platform and device type
|
||||||
|
system = platform.system()
|
||||||
|
machine = platform.machine()
|
||||||
|
os_machine = system + '-' + machine
|
||||||
|
if os_machine == 'Linux-aarch64':
|
||||||
|
try:
|
||||||
|
with open(DEVICE_COMPATIBLE_NODE) as f:
|
||||||
|
device_compatible_str = f.read()
|
||||||
|
if 'rk3562' in device_compatible_str:
|
||||||
|
host = 'RK3562'
|
||||||
|
elif 'rk3576' in device_compatible_str:
|
||||||
|
host = 'RK3576'
|
||||||
|
elif 'rk3588' in device_compatible_str:
|
||||||
|
host = 'RK3588'
|
||||||
|
else:
|
||||||
|
host = 'RK3566_RK3568'
|
||||||
|
except IOError:
|
||||||
|
print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
|
||||||
|
exit(-1)
|
||||||
|
else:
|
||||||
|
host = os_machine
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
RK3566_RK3568_RKNN_MODEL = 'resnet18_for_rk3566_rk3568.rknn'
|
||||||
|
RK3588_RKNN_MODEL = model_path
|
||||||
|
RK3562_RKNN_MODEL = 'resnet18_for_rk3562.rknn'
|
||||||
|
RK3576_RKNN_MODEL = 'resnet18_for_rk3576.rknn'
|
||||||
|
|
||||||
|
def show_top5(result):
|
||||||
|
if result is None:
|
||||||
|
print("Inference failed: result is None")
|
||||||
|
return
|
||||||
|
|
||||||
|
output = result[0].reshape(-1)
|
||||||
|
# Softmax
|
||||||
|
# output = np.exp(output) / np.sum(np.exp(output))
|
||||||
|
# Get the indices of the top 5 largest values
|
||||||
|
output_sorted_indices = np.argsort(output)[::-1][:5]
|
||||||
|
top5_str = 'resnet18\n-----TOP 5-----\n'
|
||||||
|
for i, index in enumerate(output_sorted_indices):
|
||||||
|
value = output[index]
|
||||||
|
if value > 0:
|
||||||
|
topi = '[{:>3d}] score:{:.6f} class:"{}"\n'.format(index, value, labels[index])
|
||||||
|
else:
|
||||||
|
topi = '-1: 0.0\n'
|
||||||
|
top5_str += topi
|
||||||
|
print(top5_str)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Get device information
|
||||||
|
host_name = get_host()
|
||||||
|
if host_name == 'RK3566_RK3568':
|
||||||
|
rknn_model = RK3566_RK3568_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3562':
|
||||||
|
rknn_model = RK3562_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3576':
|
||||||
|
rknn_model = RK3576_RKNN_MODEL
|
||||||
|
elif host_name == 'RK3588':
|
||||||
|
rknn_model = RK3588_RKNN_MODEL
|
||||||
|
else:
|
||||||
|
print("This demo cannot run on the current platform: {}".format(host_name))
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
rknn_lite = RKNNLite()
|
||||||
|
|
||||||
|
# Load RKNN model
|
||||||
|
print('--> Load RKNN model')
|
||||||
|
ret = rknn_lite.load_rknn(rknn_model)
|
||||||
|
if ret != 0:
|
||||||
|
print('Load RKNN model failed')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
# 读取并预处理图像 - 这是关键修改部分
|
||||||
|
ori_img = cv2.imread(image_path)
|
||||||
|
img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
|
||||||
|
# 调整尺寸
|
||||||
|
img = cv2.resize(img, target_size)
|
||||||
|
img = np.expand_dims(img, 0) # 添加batch维度
|
||||||
|
|
||||||
|
# Init runtime environment
|
||||||
|
print('--> Init runtime environment')
|
||||||
|
if host_name in ['RK3576', 'RK3588']:
|
||||||
|
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
else:
|
||||||
|
ret = rknn_lite.init_runtime()
|
||||||
|
if ret != 0:
|
||||||
|
print('Init runtime environment failed')
|
||||||
|
exit(ret)
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
print("host_name:", host_name)
|
||||||
|
print("RKNNLite.NPU_CORE_0:", RKNNLite.NPU_CORE_0)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
print('--> Running model')
|
||||||
|
outputs = rknn_lite.inference(inputs=[img])
|
||||||
|
print("outputs:", outputs)
|
||||||
|
print('Inference completed')
|
||||||
|
|
||||||
|
# Show the classification results
|
||||||
|
show_top5(outputs)
|
||||||
|
|
||||||
|
rknn_lite.release()
|
||||||
Reference in New Issue
Block a user