Files
zjsh_code_jicheng/charge_3cls/charge_cls_rknn.py
2026-03-10 16:51:57 +08:00

199 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
from collections import deque
class StableClassJudge:
    """
    Debounce per-frame classification results into a stable decision.

    Rules:
    - ``stable_frames`` (default 3) consecutive identical results -> emit that
      class ID once, then start counting afresh (prevents repeated triggers).
    - ``ignore_class`` (default 2) -> discard everything counted so far and
      start counting again.

    The history is a bounded sliding window, so after a mixed window a later
    streak can still fire without an explicit reset.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Holds at most `stable_frames` of the most recent non-ignored results.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget every buffered frame result."""
        self.buffer.clear()

    def update(self, class_id):
        """
        Feed one frame's class ID into the judge.

        Returns:
            None while the decision is not yet stable; otherwise the class ID
            that filled the whole window with identical values.
        """
        if class_id == self.ignore_class:
            # The ignored class breaks any streak in progress.
            self.reset()
            return None

        window = self.buffer
        window.append(class_id)

        window_full = len(window) >= self.stable_frames
        if not (window_full and len(set(window)) == 1):
            return None

        decided = window[0]
        self.reset()  # emit only once per streak
        return decided
# ---------------------------
# Three-class label mapping: model output index -> display label
# ---------------------------
CLASS_NAMES = {
    0: "插好",      # plugged in properly
    1: "未插好",    # not plugged in properly
    2: "有遮挡",    # view is occluded
}
# ---------------------------
# Global RKNN instance (loaded only once per process)
# ---------------------------
_global_rknn = None
# Resolved path of the model held in _global_rknn; used to detect a caller
# asking for a different model than the one already cached.
_global_rknn_path = None


def init_rknn_model(model_path):
    """
    Load an RKNN model onto NPU core 0 and cache it as a module-level singleton.

    The first call loads ``model_path`` and initialises the runtime; later
    calls for the same model file return the cached instance without
    reloading.

    Args:
        model_path: Path to the ``.rknn`` model file.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading or runtime initialisation fails (non-zero
            return code), or if a different model file is requested after one
            has already been cached.
    """
    global _global_rknn, _global_rknn_path
    resolved_path = os.path.realpath(model_path)
    if _global_rknn is not None:
        # Previously any model_path silently returned the first model loaded.
        if _global_rknn_path is not None and resolved_path != _global_rknn_path:
            raise RuntimeError(
                f"RKNN model already loaded from {_global_rknn_path}; "
                f"cannot switch to {resolved_path}"
            )
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Free the native runtime handle before propagating; otherwise every
        # failed load attempt leaks it.
        rknn.release()
        raise

    _global_rknn = rknn
    _global_rknn_path = resolved_path
    print(f"[INFO] RKNN 模型加载成功:{model_path}")
    return rknn
# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """
    Fit ``image`` inside a ``new_size`` x ``new_size`` square, keeping its
    aspect ratio and centring it on a solid ``color`` background.

    Args:
        image: Source image; assumed HxWx3 (e.g. BGR from OpenCV), since the
            padding canvas is built with 3 channels — confirm for other inputs.
        new_size: Side length of the square output.
        color: Fill value for the padded border.

    Returns:
        ``(new_size, new_size, 3)`` uint8 array.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # Round rather than truncate: h * (new_size / h) can evaluate to
    # 639.999..., and int() would leave a stray 1-pixel padding strip.
    # Clamp to [1, new_size] so extreme aspect ratios never request a
    # zero-sized output from cv2.resize.
    nh = min(new_size, max(1, round(h * scale)))
    nw = min(new_size, max(1, round(w * scale)))
    resized = cv2.resize(image, (nw, nh))  # cv2 takes (width, height)
    canvas = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    canvas[top:top + nh, left:left + nw] = resized
    return canvas
def resize_stretch(image, size=640):
    """Resize ``image`` to exactly ``size`` x ``size`` pixels, ignoring aspect ratio."""
    target_wh = (size, size)  # cv2.resize expects (width, height)
    return cv2.resize(image, target_wh)
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """
    Convert an image into a batched float32 input tensor for RKNN inference.

    Args:
        img: Input image; assumed BGR HxWx3 as returned by ``cv2.imread`` —
            confirm for other sources.
        size: Side length of the square model input.
        resize_mode: ``"stretch"`` (plain resize, aspect ratio not kept) or
            ``"letterbox"`` (aspect-preserving resize with gray padding).
            Case-insensitive.
        to_rgb: Convert BGR -> RGB before building the tensor.
        normalize: Scale pixel values from [0, 255] to [0, 1].
        layout: ``"NHWC"`` -> shape (1, size, size, C);
            ``"NCHW"`` -> shape (1, C, size, size). Case-insensitive.

    Returns:
        C-contiguous float32 array with a leading batch dimension of 1.

    Raises:
        ValueError: If ``resize_mode`` or ``layout`` is not recognised.
    """
    # Validate up front: previously any unrecognised value silently fell back
    # to "stretch" / NCHW, so a typo such as "nhwc" yielded a transposed tensor.
    mode = str(resize_mode).lower()
    if mode not in ("stretch", "letterbox"):
        raise ValueError(
            f"resize_mode must be 'stretch' or 'letterbox', got {resize_mode!r}"
        )
    layout_key = str(layout).upper()
    if layout_key not in ("NHWC", "NCHW"):
        raise ValueError(f"layout must be 'NHWC' or 'NCHW', got {layout!r}")

    if mode == "letterbox":
        img_box = letterbox(img, new_size=size)
    else:
        img_box = resize_stretch(img, size=size)
    if to_rgb:
        img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)

    img_f = img_box.astype(np.float32)
    if normalize:
        img_f /= 255.0
    if layout_key == "NCHW":
        img_f = np.transpose(img_f, (2, 0, 1))  # HWC -> CHW
    return np.ascontiguousarray(np.expand_dims(img_f, axis=0))
# ---------------------------
# Single RKNN inference (three-class)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """
    Run one inference on an already-preprocessed tensor and softmax the output.

    Args:
        input_tensor: Batched model input, e.g. from ``preprocess_image_for_rknn``.
        model_path: Path to the ``.rknn`` model (loaded once via ``init_rknn_model``).

    Returns:
        ``(class_id, probs)``: the index of the highest-probability class and
        the float32 probability vector over the flattened first output.

    Raises:
        RuntimeError: If the runtime returns no output or an empty output.
    """
    rknn = init_rknn_model(model_path)
    outs = rknn.inference([input_tensor])
    # NOTE(review): RKNNLite.inference appears to return None on failure —
    # check explicitly instead of failing later with an opaque TypeError.
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")

    # Flatten the first output head; expected to hold 3 logits, one per
    # class — TODO confirm against the exported model.
    logits = np.asarray(outs[0], dtype=np.float32).reshape(-1)
    if logits.size == 0:
        raise RuntimeError("RKNN inference returned an empty output tensor")

    # Numerically stable softmax: subtracting the max keeps exp() from overflowing.
    # NOTE(review): if the exported model already ends in a softmax layer,
    # this re-normalises probabilities and flattens the scores — confirm.
    exp = np.exp(logits - np.max(logits))
    probs = exp / np.sum(exp)
    class_id = int(np.argmax(probs))
    return class_id, probs
# ---------------------------
# Single-image inference (three-class); ROI cropping has been removed
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """
    Classify a whole frame into one of the labels in ``CLASS_NAMES``.

    The full frame is used; no ROI cropping is applied.

    Args:
        frame: Image array; assumed BGR as returned by ``cv2.imread`` —
            confirm for other sources.
        model_path: Path to the ``.rknn`` model file.
        size, resize_mode, to_rgb, normalize, layout: Forwarded unchanged to
            ``preprocess_image_for_rknn``.

    Returns:
        dict with keys:
            ``class_id``: predicted class index (int);
            ``class``: its label from ``CLASS_NAMES``, or an "unknown" placeholder;
            ``score``: probability of the predicted class, rounded to 4 decimals;
            ``raw``: the full probability vector as a list.

    Raises:
        FileNotFoundError: If ``frame`` is None or contains no pixels. This
            exception type is kept for compatibility with existing callers.
    """
    # An empty array (e.g. from a failed crop or read) is as unusable as None;
    # reject it here rather than letting cv2 fail obscurely downstream.
    if frame is None or getattr(frame, "size", None) == 0:
        raise FileNotFoundError("❌ 输入帧为空")

    # Use the whole frame directly; no cropping.
    input_tensor = preprocess_image_for_rknn(
        frame,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout
    )
    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist()
    }
# ---------------------------
# Example usage: classify one local image with the charge-state model
# ---------------------------
if __name__ == "__main__":
    demo_model = "charge_cls.rknn"
    demo_image = "class2.png"  # ROI file is no longer needed

    img = cv2.imread(demo_image)
    if img is None:
        raise FileNotFoundError(f"❌ 无法读取图片:{demo_image}")

    print("[RESULT]", classify_single_image(img, demo_model))