更新charge振捣判断
This commit is contained in:
BIN
charge_3cls/charge_cls.rknn
Normal file
BIN
charge_3cls/charge_cls.rknn
Normal file
Binary file not shown.
198
charge_3cls/charge_cls_rknn.py
Normal file
198
charge_3cls/charge_cls_rknn.py
Normal file
@ -0,0 +1,198 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from rknnlite.api import RKNNLite
|
||||
from collections import deque
|
||||
|
||||
class StableClassJudge:
|
||||
"""
|
||||
连续三帧稳定判决器:
|
||||
- class0 / class1 连续 3 帧 -> 输出
|
||||
- class2 -> 清空计数,重新统计
|
||||
"""
|
||||
|
||||
def __init__(self, stable_frames=3, ignore_class=2):
|
||||
self.stable_frames = stable_frames
|
||||
self.ignore_class = ignore_class
|
||||
self.buffer = deque(maxlen=stable_frames)
|
||||
|
||||
def reset(self):
|
||||
self.buffer.clear()
|
||||
|
||||
def update(self, class_id):
|
||||
"""
|
||||
输入单帧分类结果
|
||||
返回:
|
||||
- None:尚未稳定
|
||||
- class_id:稳定输出结果
|
||||
"""
|
||||
|
||||
# 遇到 class2,直接清空重新计数
|
||||
if class_id == self.ignore_class:
|
||||
self.reset()
|
||||
return None
|
||||
|
||||
self.buffer.append(class_id)
|
||||
|
||||
# 缓冲未满
|
||||
if len(self.buffer) < self.stable_frames:
|
||||
return None
|
||||
|
||||
# 三帧完全一致
|
||||
if len(set(self.buffer)) == 1:
|
||||
stable_class = self.buffer[0]
|
||||
self.reset() # 输出一次后重新计数(防止重复触发)
|
||||
return stable_class
|
||||
|
||||
return None
|
||||
|
||||
# ---------------------------
|
||||
# 三分类映射
|
||||
# ---------------------------
|
||||
CLASS_NAMES = {
|
||||
0: "插好",
|
||||
1: "未插好",
|
||||
2: "有遮挡"
|
||||
}
|
||||
|
||||
# ---------------------------
|
||||
# RKNN 全局实例(只加载一次)
|
||||
# ---------------------------
|
||||
_global_rknn = None
|
||||
|
||||
|
||||
def init_rknn_model(model_path):
|
||||
global _global_rknn
|
||||
if _global_rknn is not None:
|
||||
return _global_rknn
|
||||
|
||||
rknn = RKNNLite(verbose=False)
|
||||
ret = rknn.load_rknn(model_path)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Load RKNN failed: {ret}")
|
||||
|
||||
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Init runtime failed: {ret}")
|
||||
|
||||
_global_rknn = rknn
|
||||
print(f"[INFO] RKNN 模型加载成功:{model_path}")
|
||||
return rknn
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 预处理
|
||||
# ---------------------------
|
||||
def letterbox(image, new_size=640, color=(114, 114, 114)):
|
||||
h, w = image.shape[:2]
|
||||
scale = min(new_size / h, new_size / w)
|
||||
nh, nw = int(h * scale), int(w * scale)
|
||||
resized = cv2.resize(image, (nw, nh))
|
||||
new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
|
||||
top = (new_size - nh) // 2
|
||||
left = (new_size - nw) // 2
|
||||
new_img[top:top + nh, left:left + nw] = resized
|
||||
return new_img
|
||||
|
||||
|
||||
def resize_stretch(image, size=640):
|
||||
return cv2.resize(image, (size, size))
|
||||
|
||||
|
||||
def preprocess_image_for_rknn(
|
||||
img,
|
||||
size=640,
|
||||
resize_mode="stretch",
|
||||
to_rgb=True,
|
||||
normalize=False,
|
||||
layout="NHWC"
|
||||
):
|
||||
if resize_mode == "letterbox":
|
||||
img_box = letterbox(img, new_size=size)
|
||||
else:
|
||||
img_box = resize_stretch(img, size=size)
|
||||
|
||||
if to_rgb:
|
||||
img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)
|
||||
|
||||
img_f = img_box.astype(np.float32)
|
||||
|
||||
if normalize:
|
||||
img_f /= 255.0
|
||||
|
||||
if layout == "NHWC":
|
||||
out = np.expand_dims(img_f, axis=0)
|
||||
else:
|
||||
out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)
|
||||
|
||||
return np.ascontiguousarray(out)
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 单次 RKNN 推理(三分类)
|
||||
# ---------------------------
|
||||
def rknn_classify_preprocessed(input_tensor, model_path):
|
||||
rknn = init_rknn_model(model_path)
|
||||
|
||||
outs = rknn.inference([input_tensor])
|
||||
logits = outs[0].reshape(-1).astype(np.float32) # shape = (3,)
|
||||
|
||||
# softmax
|
||||
exp = np.exp(logits - np.max(logits))
|
||||
probs = exp / np.sum(exp)
|
||||
|
||||
class_id = int(np.argmax(probs))
|
||||
return class_id, probs
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 单张图片推理(三分类)- 已移除 ROI 逻辑
|
||||
# ---------------------------
|
||||
def classify_single_image(
|
||||
frame,
|
||||
model_path,
|
||||
size=640,
|
||||
resize_mode="stretch",
|
||||
to_rgb=True,
|
||||
normalize=False,
|
||||
layout="NHWC"
|
||||
):
|
||||
if frame is None:
|
||||
raise FileNotFoundError("❌ 输入帧为空")
|
||||
|
||||
# 直接使用整图,不再裁剪
|
||||
input_tensor = preprocess_image_for_rknn(
|
||||
frame,
|
||||
size=size,
|
||||
resize_mode=resize_mode,
|
||||
to_rgb=to_rgb,
|
||||
normalize=normalize,
|
||||
layout=layout
|
||||
)
|
||||
|
||||
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
|
||||
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
|
||||
|
||||
return {
|
||||
"class_id": class_id,
|
||||
"class": class_name,
|
||||
"score": round(float(probs[class_id]), 4),
|
||||
"raw": probs.tolist()
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# 示例调用
|
||||
# ---------------------------
|
||||
if __name__ == "__main__":
|
||||
model_path = "charge_cls.rknn"
|
||||
# roi_file 已移除
|
||||
image_path = "class2.png"
|
||||
|
||||
frame = cv2.imread(image_path)
|
||||
if frame is None:
|
||||
raise FileNotFoundError(f"❌ 无法读取图片:{image_path}")
|
||||
|
||||
# 调用
|
||||
result = classify_single_image(frame, model_path)
|
||||
print("[RESULT]", result)
|
||||
BIN
charge_3cls/class1.png
Normal file
BIN
charge_3cls/class1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 MiB |
BIN
charge_3cls/class2.png
Normal file
BIN
charge_3cls/class2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 MiB |
Reference in New Issue
Block a user