# (file-viewer metadata: 132 lines, 3.5 KiB, Python)
import os
|
||
import cv2
|
||
import numpy as np
|
||
|
||
from collections import deque
|
||
|
||
class StableClassJudge:
    """Debounce per-frame class predictions until they are stable.

    Predictions are fed one frame at a time through ``update``:

    - When the last ``stable_frames`` predictions (default 3) are all the same
      class (e.g. class 0 or class 1), that class is returned and the window
      starts over.
    - A prediction equal to ``ignore_class`` (default 2) clears the window so
      counting restarts from scratch.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Sliding window holding the most recent non-ignored predictions.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget every buffered prediction."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one frame's class id.

        Returns the class id once ``stable_frames`` identical ids have been seen
        in a row (the window is then cleared); otherwise returns ``None``.
        """
        if class_id == self.ignore_class:
            self.reset()
            return None

        window = self.buffer
        window.append(class_id)

        window_full = not len(window) < self.stable_frames
        if window_full and len(set(window)) == 1:
            winner = window[0]
            self.reset()
            return winner
        return None
|
||
|
||
# ---------------------------
# Three-class label mapping: model output index -> display label
# ---------------------------
CLASS_NAMES = dict(
    enumerate(
        (
            "插好",  # 0: "plugged in"
            "未插好",  # 1: "not plugged in"
            "有遮挡",  # 2: "occluded"
        )
    )
)
|
||
|
||
# ---------------------------
# Global RKNN instance (loaded once per model path, then reused)
# ---------------------------
_global_rknn = None
# Absolute path of the model held in _global_rknn (None while nothing is loaded).
_global_rknn_path = None


def init_rknn_model(model_path):
    """Return an initialized RKNNLite runtime for ``model_path``, loading it on first use.

    The runtime is cached at module level, so repeated calls with the same model
    path do not reload the model. Calling with a *different* path releases the
    previously cached runtime and loads the requested model, instead of silently
    returning the old one.

    Args:
        model_path: Path to a ``.rknn`` model file.

    Returns:
        An ``RKNNLite`` instance whose runtime was initialized on NPU core 0.

    Raises:
        RuntimeError: If loading the model or initializing the runtime fails.
    """
    global _global_rknn, _global_rknn_path

    wanted_path = os.path.abspath(model_path)
    if _global_rknn is not None and _global_rknn_path == wanted_path:
        return _global_rknn

    # Imported lazily (function scope) so this module can be imported where
    # rknnlite is not installed.
    from rknnlite.api import RKNNLite

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Free the runtime handle; otherwise every failed attempt leaks it.
        rknn.release()
        raise

    if _global_rknn is not None:
        # A different model was cached; release it before replacing it.
        _global_rknn.release()

    _global_rknn = rknn
    _global_rknn_path = wanted_path
    print(f"[INFO] RKNN 模型加载成功:{model_path}")
    return rknn
|
||
|
||
|
||
# ---------------------------
# Preprocessing (model input is uint8; per the original note, RKNN converts
# it to float32 internally)
# ---------------------------
def resize_stretch(image, size=640):
    """Resize ``image`` to a ``size`` x ``size`` square without preserving aspect ratio."""
    target_wh = (size, size)
    return cv2.resize(image, target_wh)
|
||
|
||
def preprocess_image_for_rknn(img, size=640):
    """Convert an OpenCV image into the NHWC uint8 tensor fed to the RKNN model.

    The image is stretched (aspect ratio not preserved) to ``size`` x ``size``,
    converted to RGB channel order and given a leading batch axis. The tensor
    stays uint8 in [0, 255] even though the model is float32 (per the original
    note, RKNN performs the float conversion internally).

    Accepted channel layouts:
        - 3-channel BGR (the ``cv2.imread`` default),
        - 4-channel BGRA (alpha is dropped),
        - single-channel grayscale, ``(H, W)`` or ``(H, W, 1)``.

    Args:
        img: Image array.
        size: Side length of the square model input.

    Returns:
        C-contiguous ``uint8`` array of shape ``(1, size, size, 3)``.

    Raises:
        ValueError: If the image has an unsupported number of channels.
    """
    resized = resize_stretch(img, size=size)

    # Pick the conversion by channel count; BGR2RGB alone crashes on grayscale.
    channels = 1 if resized.ndim == 2 else resized.shape[2]
    conversions = {
        1: cv2.COLOR_GRAY2RGB,
        3: cv2.COLOR_BGR2RGB,
        4: cv2.COLOR_BGRA2RGB,
    }
    code = conversions.get(channels)
    if code is None:
        raise ValueError(f"Unsupported number of image channels: {channels}")
    rgb = cv2.cvtColor(resized, code)

    if rgb.dtype != np.uint8:
        # Clip first: a bare astype(np.uint8) overflows/wraps out-of-range values
        # (e.g. uint16 256 -> 0) instead of saturating them.
        rgb = np.clip(rgb, 0, 255)

    input_tensor = np.expand_dims(rgb, axis=0).astype(np.uint8)  # NHWC, uint8
    return np.ascontiguousarray(input_tensor)
|
||
|
||
|
||
# ---------------------------
# Single RKNN inference (three-class, float32 model)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference on an already-preprocessed input tensor.

    Args:
        input_tensor: Tensor produced by ``preprocess_image_for_rknn``.
        model_path: Model file; the runtime is loaded and cached via
            ``init_rknn_model``.

    Returns:
        ``(class_id, probs)``. ``class_id`` is the argmax index of the first
        model output. ``probs`` is that output flattened to float32. Per the
        original author's note these are raw logits (no softmax is applied
        here), so values are not necessarily in [0, 1].
        NOTE(review): confirm whether the exported model ends in a softmax.

    Raises:
        RuntimeError: If the runtime produced no output or an empty output.
    """
    rknn = init_rknn_model(model_path)
    outs = rknn.inference([input_tensor])

    # The runtime may report failure by returning None rather than raising;
    # fail here with a clear message instead of a TypeError on outs[0].
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")

    # Raw model output (described as logits); expected to hold one value per class.
    probs = np.asarray(outs[0]).flatten().astype(np.float32)
    if probs.size == 0:
        raise RuntimeError("RKNN inference returned an empty output tensor")

    class_id = int(np.argmax(probs))
    return class_id, probs
|
||
|
||
|
||
# ---------------------------
# Single-image inference
# ---------------------------
def classify_single_image(frame, model_path, size=640):
    """Classify one image with the RKNN three-class model.

    Args:
        frame: Image array in OpenCV layout (e.g. as returned by ``cv2.imread``).
        model_path: Path to the ``.rknn`` model.
        size: Side length of the square input the frame is stretched to.

    Returns:
        dict with keys:
            ``class_id`` (int argmax index),
            ``class`` (label from ``CLASS_NAMES``, or an "unknown class" label),
            ``score`` (the model output for the winning class, rounded to 4
            decimals) and ``raw`` (all model outputs as a list of floats).

    Raises:
        FileNotFoundError: If ``frame`` is None or has no pixels. This exception
            type is kept for backward compatibility even though no file is read.
    """
    # Zero-size arrays would otherwise crash deep inside cv2.resize.
    if frame is None or np.asarray(frame).size == 0:
        raise FileNotFoundError("输入帧为空")

    input_tensor = preprocess_image_for_rknn(frame, size=size)
    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")

    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist(),
    }
|
||
|
||
|
||
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    # Demo: classify one image from disk and print the result dict.
    demo_model, demo_image = "charge0324.rknn", "class2.jpg"

    bgr_frame = cv2.imread(demo_image)
    if bgr_frame is None:
        raise FileNotFoundError(f"无法读取图片:{demo_image}")

    print("[RESULT]", classify_single_image(bgr_frame, demo_model))
|