# Source file: zjsh_code_jicheng/muju_cls/muju_cls_rknn.py
# Last modified: 2025-12-28 00:14:08 +08:00 (283 lines, 6.9 KiB, Python)
#
# Repository viewer notice: this file contains ambiguous Unicode characters
# that might be confused with other characters. If this is intentional, the
# warning can safely be ignored.

import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
from collections import deque
class StableClassJudge:
    """Turn a stream of per-frame class ids into occasional stable decisions.

    A class id is reported only when the last ``stable_frames`` calls to
    :meth:`update` all delivered that same id. Receiving ``ignore_class``
    (2 by default) wipes the history so that counting starts over.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        """Create a judge.

        Args:
            stable_frames: Number of consecutive identical results required
                before a class is reported. Must be at least 1.
            ignore_class: Class id that resets the history instead of being
                counted.

        Raises:
            ValueError: If ``stable_frames`` is smaller than 1.
        """
        # A zero-length window can never fill up, so such a judge would
        # silently never report anything; fail loudly instead.
        if stable_frames < 1:
            raise ValueError(
                f"stable_frames must be >= 1, got {stable_frames}"
            )
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Bounded deque: appending to a full buffer drops the oldest entry,
        # which gives a sliding window over the most recent frames.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget all frames seen so far."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed the classification result of one frame.

        Args:
            class_id: Class id predicted for the current frame.

        Returns:
            ``None`` while no stable decision is available, otherwise the
            class id shared by the last ``stable_frames`` frames.
        """
        # The ignored class clears the history and restarts counting.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Window not full yet.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every frame in the window agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            # Start counting afresh after reporting once, so the same run of
            # frames does not trigger repeatedly.
            self.reset()
            return stable_class

        return None
# ---------------------------
# Three-class label mapping: cart 1 is the small mold cart, cart 2 the large
# one; class 2 means the view is occluded.
# ---------------------------
CLASS_NAMES = {
    0: "模具车1",
    1: "模具车2",
    2: "有遮挡",
}
# ---------------------------
# Global RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None


def init_rknn_model(model_path):
    """Return the shared RKNNLite runtime, creating it on the first call.

    The runtime is cached in the module-level ``_global_rknn``. Once a model
    has been loaded, later calls return the cached instance and do not look
    at ``model_path`` again.

    Args:
        model_path: Path of the ``.rknn`` model file to load.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status.
    """
    global _global_rknn
    if _global_rknn is not None:
        return _global_rknn

    rknn = RKNNLite(verbose=False)

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        # Release the half-initialised handle so it is not leaked.
        rknn.release()
        raise RuntimeError(f"Load RKNN failed: {ret}")

    ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        rknn.release()
        raise RuntimeError(f"Init runtime failed: {ret}")

    # Publish the instance only after both steps succeeded.
    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Fit ``image`` into a square canvas while keeping its aspect ratio.

    The image is scaled so that its longer side equals ``new_size`` and is
    then centred on a ``new_size`` x ``new_size`` canvas filled with
    ``color``.

    Args:
        image: Input image; the canvas is created with 3 channels and dtype
            ``uint8``, so the input is expected to have 3 channels.
        new_size: Side length of the square output.
        color: Fill value for the padded border.

    Returns:
        A ``(new_size, new_size, 3)`` ``uint8`` array.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # For very elongated inputs the short side can truncate to 0, which
    # cv2.resize rejects; keep at least one pixel.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))

    new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    # Centre the resized image; any odd leftover pixel goes to bottom/right.
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img
def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without keeping aspect ratio."""
    target = (size, size)
    return cv2.resize(image, target)
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Convert an image into a batched float32 tensor for inference.

    Args:
        img: Input image as an array; treated as BGR when ``to_rgb`` is set.
        size: Side length of the square network input.
        resize_mode: ``"letterbox"`` pads to keep the aspect ratio; any other
            value stretches the image to the target size.
        to_rgb: Convert BGR to RGB after resizing.
        normalize: Divide pixel values by 255.
        layout: ``"NHWC"`` keeps channels last; any other value moves the
            channel axis to the front (NCHW).

    Returns:
        A C-contiguous float32 array with a leading batch axis of length 1.
    """
    if resize_mode == "letterbox":
        resized = letterbox(img, new_size=size)
    else:
        resized = resize_stretch(img, size=size)

    if to_rgb:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    tensor = resized.astype(np.float32)
    if normalize:
        tensor /= 255.0

    # Anything other than NHWC is treated as channels-first.
    if layout != "NHWC":
        tensor = np.transpose(tensor, (2, 0, 1))

    batched = tensor[np.newaxis, ...]
    return np.ascontiguousarray(batched)
# ---------------------------
# Single RKNN inference (three-class)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference on an already preprocessed tensor.

    Args:
        input_tensor: Preprocessed input array; converted to a C-contiguous
            float32 array before being handed to the runtime.
        model_path: Model path forwarded to ``init_rknn_model``.

    Returns:
        A ``(class_id, pred)`` tuple where ``pred`` is the first model output
        flattened to 1-D float values and ``class_id`` is the index of its
        largest entry.

    Raises:
        RuntimeError: If the runtime returns no output.
    """
    rknn = init_rknn_model(model_path)
    # One conversion handles both dtype and memory layout, and avoids a copy
    # when the input already satisfies both.
    input_tensor = np.ascontiguousarray(input_tensor, dtype=np.float32)
    outs = rknn.inference([input_tensor])
    # Fail with a clear message instead of an opaque error on outs[0].
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")
    # Presumably three scores, one per class in CLASS_NAMES -- confirm
    # against the model.
    pred = outs[0].reshape(-1).astype(float)
    class_id = int(np.argmax(pred))
    return class_id, pred
# ---------------------------
# ROI
# ---------------------------
def load_single_roi(txt_path):
    """Read the first ROI from a text file.

    The first non-blank line must hold four comma-separated integers
    ``x,y,w,h``; any later lines are not read.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        The ROI as an ``(x, y, w, h)`` tuple of ints.

    Raises:
        RuntimeError: If the file does not exist or has no non-blank line.
        ValueError: If the first non-blank line is not four integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")

    # utf-8-sig drops a leading byte-order mark, which would otherwise be
    # glued to the first number and make int() fail.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(","))
            except ValueError as err:
                # Covers both a wrong field count and non-integer fields.
                raise ValueError(
                    f"ROI 格式错误(应为 x,y,w,h): {s!r}"
                ) from err
            return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")
def crop_and_return_roi(img, roi):
    """Return the sub-image of ``img`` described by ``roi``.

    Args:
        img: Image array whose first two axes are height and width.
        roi: ``(x, y, w, h)`` rectangle in pixel coordinates.

    Returns:
        The slice ``img[y:y + h, x:x + w]`` (a view, not a copy).

    Raises:
        RuntimeError: If the ROI has a non-positive width or height, or does
            not lie completely inside the image.
    """
    x, y, w, h = roi
    # A zero or negative extent would pass the bounds check below and yield
    # an empty (or wrapped-around) slice, so reject it explicitly.
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI 宽高必须为正: {roi}")
    h_img, w_img = img.shape[:2]
    if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")
    return img[y:y + h, x:x + w]
# ---------------------------
# Single-image inference (three-class)
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    roi_file,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Classify the ROI of one frame.

    The ROI is read from ``roi_file`` on every call, cropped out of
    ``frame``, preprocessed and passed to the RKNN model.

    Args:
        frame: Input image array.
        model_path: Path of the RKNN model.
        roi_file: Text file holding the ROI as ``x,y,w,h``.
        size, resize_mode, to_rgb, normalize, layout: Forwarded unchanged to
            ``preprocess_image_for_rknn``.

    Returns:
        A dict with keys ``class_id``, ``class`` (display name), ``score``
        (model output for the winning class, rounded to 4 decimals) and
        ``raw`` (all model outputs as a list).

    Raises:
        FileNotFoundError: If ``frame`` is ``None``.
    """
    if frame is None:
        raise FileNotFoundError("输入帧为空")

    region = load_single_roi(roi_file)
    cropped = crop_and_return_roi(frame, region)

    tensor = preprocess_image_for_rknn(
        cropped,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout,
    )

    predicted, scores = rknn_classify_preprocessed(tensor, model_path)
    # Fall back to a generic label for ids missing from the mapping.
    label = CLASS_NAMES.get(predicted, f"未知类别({predicted})")
    confidence = round(float(scores[predicted]), 4)

    return {
        "class_id": predicted,
        "class": label,
        "score": confidence,
        "raw": scores.tolist(),
    }
# ---------------------------
# Example invocation
# ---------------------------
if __name__ == "__main__":
    # Read the test picture first so a bad path fails before the model loads.
    image_path = "./test_image/test.png"
    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"无法读取图片: {image_path}")

    model_path = "muju_cls.rknn"
    roi_file = "./roi_coordinates/muju_roi.txt"

    result = classify_single_image(frame, model_path, roi_file)
    print("[RESULT]", result)
# ---------------------------
# Example stable-decision loop. It is kept inside a string literal so it is
# never executed; copy it into a separate script to use it. The read-status
# check comes before cv2.flip so a failed read is not passed to OpenCV.
'''
import cv2
from muju_cls_rknn import classify_single_image,StableClassJudge,CLASS_NAMES
def run_stable_classification_loop(
    model_path,
    roi_file,
    image_source,
    stable_frames=3
):
    """
    image_source:
    - cv2.VideoCapture
    """
    judge = StableClassJudge(
        stable_frames=stable_frames,
        ignore_class=2 # 有遮挡
    )
    cap = image_source
    if not hasattr(cap, "read"):
        raise TypeError("image_source 必须是 cv2.VideoCapture")
    while True:
        ret, frame = cap.read()
        if not ret:
            print("读取帧失败,退出")
            break
        # 上下左右翻转
        frame = cv2.flip(frame, -1)
        result = classify_single_image(frame, model_path, roi_file)
        class_id = result["class_id"]
        class_name = result["class"]
        score = result["score"]
        print(f"[FRAME] {class_name} conf={score}")
        stable = judge.update(class_id)
        if stable is not None:
            print(f"\n稳定输出: {CLASS_NAMES[stable]} \n")
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
'''
# ---------------------------