283 lines
6.9 KiB
Python
283 lines
6.9 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
from rknnlite.api import RKNNLite
|
||
|
||
from collections import deque
|
||
|
||
class StableClassJudge:
    """Debounce per-frame class predictions into a stable decision.

    A class is reported only after ``stable_frames`` consecutive (non-ignored)
    frames all predict that same class. A frame whose class equals
    ``ignore_class`` (class 2, "occluded", by default) clears the history so
    counting restarts. After a class is reported the history is cleared too,
    so a continuously stable stream reports once every ``stable_frames``
    frames rather than on every frame.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        """
        Args:
            stable_frames: number of consecutive agreeing frames required
                before a class is reported; must be >= 1.
            ignore_class: class id that resets the history instead of being
                counted.

        Raises:
            ValueError: if ``stable_frames`` is less than 1. A zero-length
                window can never fill, so the judge would never report.
        """
        if stable_frames < 1:
            raise ValueError(f"stable_frames must be >= 1, got {stable_frames}")
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        # Bounded window: once full, each append drops the oldest entry, so the
        # buffer always holds the most recent `stable_frames` predictions.
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Discard all buffered predictions."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one frame's predicted class and report a stable result if reached.

        Args:
            class_id: class predicted for the current frame.

        Returns:
            None while no stable decision exists yet (window not full, the
            frames in the window disagree, or ``class_id`` is the ignored
            class); otherwise the class id shared by every frame in the
            window.
        """
        # Ignored class (e.g. occlusion): drop all history and start over.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Window not yet full.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every frame in the window agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            self.reset()  # report once, then restart counting to avoid repeats
            return stable_class

        return None
|
||
|
||
# ---------------------------
# Three-class label map: "模具车1" (mold cart 1) is the small cart,
# "模具车2" (mold cart 2) is the large cart, and class 2 ("有遮挡") marks an
# occluded view.
# ---------------------------
CLASS_NAMES = dict(enumerate(("模具车1", "模具车2", "有遮挡")))
|
||
|
||
# ---------------------------
# Process-wide RKNN instance (loaded once, then reused)
# ---------------------------
_global_rknn = None
# Model path currently held by `_global_rknn`. It lets a request for a
# different model file trigger a reload instead of silently returning the
# wrong model.
_global_rknn_path = None


def init_rknn_model(model_path):
    """Return an initialised RKNNLite instance for *model_path*, loading it on first use.

    The instance is cached at module level. Repeated calls with the same path
    return the cached object. A call with a different path loads the new
    model and, only after that succeeds, releases the previous one.

    Args:
        model_path: path to the ``.rknn`` model file.

    Returns:
        The RKNNLite runtime, bound to NPU core 0.

    Raises:
        RuntimeError: if loading the model or initialising the runtime fails.
            The partially constructed runtime is released first.
    """
    global _global_rknn, _global_rknn_path

    if _global_rknn is not None and _global_rknn_path == model_path:
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Don't leak the native runtime handle when setup fails.
        rknn.release()
        raise

    # Replace a previously loaded (different) model only after the new one is ready.
    if _global_rknn is not None:
        _global_rknn.release()

    _global_rknn = rknn
    _global_rknn_path = model_path
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
|
||
|
||
|
||
# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Fit *image* into a ``new_size`` x ``new_size`` square, preserving aspect ratio.

    The image is scaled by the largest factor that fits both dimensions. It is
    then centred on a canvas filled with ``color``.

    Args:
        image: H x W (grayscale) or H x W x C array. For 3-D input the channel
            count must equal ``len(color)``.
        new_size: side length of the square output.
        color: padding value per channel. For 2-D input only ``color[0]`` is
            used.

    Returns:
        uint8 array of shape (new_size, new_size) for 2-D input, or
        (new_size, new_size, C) for 3-D input.

    Raises:
        ValueError: if the image has zero height or width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"letterbox: empty image with shape {image.shape}")

    scale = min(new_size / h, new_size / w)
    # Clamp to >= 1 px: very thin images would otherwise truncate to a zero
    # dimension, which cv2.resize rejects.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))

    if image.ndim == 2:
        new_img = np.full((new_size, new_size), color[0], dtype=np.uint8)
    else:
        new_img = np.full((new_size, new_size, image.shape[2]), color, dtype=np.uint8)

    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img
|
||
|
||
|
||
def resize_stretch(image, size=640):
    """Resize *image* to exactly ``size`` x ``size`` pixels, ignoring aspect ratio."""
    target = (size, size)  # cv2 takes (width, height); square, so order is moot
    return cv2.resize(image, target)
|
||
|
||
|
||
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC",
):
    """Turn an OpenCV image into a batch-of-one float32 tensor for RKNN inference.

    Args:
        img: image array; assumed BGR (OpenCV order) when ``to_rgb`` is True.
        size: side length of the square network input.
        resize_mode: "stretch" (plain resize) or "letterbox" (aspect-preserving
            resize with padding). Case-insensitive.
        to_rgb: convert BGR -> RGB before building the tensor.
        normalize: divide pixel values by 255.0; otherwise values stay 0-255.
        layout: "NHWC" or "NCHW" (case-insensitive) for the output tensor.

    Returns:
        C-contiguous float32 array of shape (1, size, size, C) for NHWC or
        (1, C, size, size) for NCHW.

    Raises:
        ValueError: on an unrecognised ``resize_mode`` or ``layout``.
    """
    # Validate up front: an unknown value used to fall back silently to
    # stretch / NCHW, which feeds the model a wrong tensor with no error.
    fmt = layout.upper()
    if fmt not in ("NHWC", "NCHW"):
        raise ValueError(f"layout must be 'NHWC' or 'NCHW', got {layout!r}")

    mode = resize_mode.lower()
    if mode == "letterbox":
        img_box = letterbox(img, new_size=size)
    elif mode == "stretch":
        img_box = resize_stretch(img, size=size)
    else:
        raise ValueError(
            f"resize_mode must be 'stretch' or 'letterbox', got {resize_mode!r}"
        )

    if to_rgb:
        img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)

    img_f = img_box.astype(np.float32)
    if normalize:
        img_f /= 255.0

    if fmt == "NHWC":
        out = np.expand_dims(img_f, axis=0)
    else:
        # HWC -> CHW, then add the batch axis.
        out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)

    return np.ascontiguousarray(out)
|
||
|
||
|
||
# ---------------------------
# Single RKNN inference (3-class)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run the classifier on an already-preprocessed tensor.

    Args:
        input_tensor: array such as ``preprocess_image_for_rknn`` produces. It
            is converted to contiguous float32 before inference.
        model_path: path to the .rknn model. It is loaded via
            ``init_rknn_model``, which caches the runtime.

    Returns:
        ``(class_id, pred)``:
        - ``class_id`` is the index of the largest output value.
        - ``pred`` is the first model output flattened to 1-D float64.
        NOTE(review): whether ``pred`` holds probabilities or raw logits
        depends on how the model was exported; confirm before treating it as a
        confidence.

    Raises:
        RuntimeError: if the runtime returns no output or an empty output.
    """
    rknn = init_rknn_model(model_path)

    input_tensor = np.ascontiguousarray(input_tensor, dtype=np.float32)
    outs = rknn.inference([input_tensor])

    # inference() can return None on failure; fail with a clear message
    # instead of a TypeError on outs[0].
    if not outs:
        raise RuntimeError("RKNN inference returned no output")

    pred = outs[0].reshape(-1).astype(float)  # expected length: number of classes
    if pred.size == 0:
        raise RuntimeError("RKNN inference returned an empty output tensor")

    class_id = int(np.argmax(pred))
    return class_id, pred
|
||
|
||
# ---------------------------
# ROI
# ---------------------------
def load_single_roi(txt_path):
    """Read the first ROI rectangle from a text file.

    Each non-blank line is expected to be ``x,y,w,h`` (integers, surrounding
    spaces allowed). Blank lines are skipped, and only the first non-blank
    line is used.

    Args:
        txt_path: path to the ROI text file.

    Returns:
        ``(x, y, w, h)`` tuple of ints.

    Raises:
        RuntimeError: if the file does not exist or contains no non-blank line.
        ValueError: if the first non-blank line is not four comma-separated
            integers.
    """
    # EAFP instead of exists()+open(): avoids the check/use race and a second
    # filesystem lookup, while keeping the original RuntimeError for callers.
    try:
        with open(txt_path) as f:
            lines = f.read().splitlines()
    except FileNotFoundError as e:
        raise RuntimeError(f"ROI 文件不存在: {txt_path}") from e

    for line in lines:
        s = line.strip()
        if not s:
            continue
        try:
            x, y, w, h = (int(part) for part in s.split(','))
        except ValueError as e:
            raise ValueError(
                f"ROI 格式错误, 应为 'x,y,w,h': {s!r} ({txt_path})"
            ) from e
        return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")
|
||
|
||
|
||
def crop_and_return_roi(img, roi):
    """Return the sub-image of *img* selected by ``roi = (x, y, w, h)``.

    The result is a NumPy slice (a view, not a copy) of ``img``.

    Args:
        img: image array indexed as ``img[row, col, ...]``.
        roi: ``(x, y, w, h)``: top-left column, top-left row, width, height.

    Raises:
        RuntimeError: if the width or height is not positive, or if the
            rectangle extends outside the image.
    """
    x, y, w, h = roi
    h_img, w_img = img.shape[:2]

    # A zero size would give an empty crop that only fails later inside
    # cv2.resize. A negative size would slice the wrong region (e.g. 0:-1).
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI 尺寸无效: {roi}")
    if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")

    return img[y:y + h, x:x + w]
|
||
|
||
|
||
# ---------------------------
# Single-image inference (3-class)
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    roi_file,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC",
    roi=None,
):
    """Crop the ROI from *frame*, preprocess it, and classify it with the RKNN model.

    Args:
        frame: full input image (e.g. from ``cv2.imread`` / ``VideoCapture.read``).
        model_path: path to the .rknn model.
        roi_file: text file holding ``x,y,w,h``. It is read only when ``roi``
            is None.
        size, resize_mode, to_rgb, normalize, layout: forwarded to
            ``preprocess_image_for_rknn``.
        roi: optional pre-loaded ``(x, y, w, h)``. Pass it in per-frame loops
            to avoid re-reading ``roi_file`` from disk on every call.

    Returns:
        dict with these keys:
        - "class_id": predicted class index.
        - "class": label from ``CLASS_NAMES``, or "未知类别(<id>)" if unmapped.
        - "score": model output at "class_id", rounded to 4 decimals. It is
          not necessarily a probability.
        - "raw": all output values as a list.

    Raises:
        FileNotFoundError: if ``frame`` is None. This type is kept for
            existing callers.
        RuntimeError: propagated from ROI loading/cropping or model loading.
    """
    if frame is None:
        raise FileNotFoundError("输入帧为空")

    if roi is None:
        roi = load_single_roi(roi_file)
    roi_img = crop_and_return_roi(frame, roi)

    input_tensor = preprocess_image_for_rknn(
        roi_img,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout,
    )

    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
    class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")

    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist(),
    }
|
||
|
||
|
||
|
||
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    model_path = "muju_cls.rknn"
    roi_file = "./roi_coordinates/muju_roi.txt"
    image_path = "./test_image/test.png"

    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"无法读取图片: {image_path}")

    result = classify_single_image(frame, model_path, roi_file)
    print("[RESULT]", result)

# ---------------------------
# Example decision logic (reference snippet for a separate script; this
# string literal is not executed)
'''
import cv2

from muju_cls_rknn import classify_single_image, StableClassJudge, CLASS_NAMES


def run_stable_classification_loop(
    model_path,
    roi_file,
    image_source,
    stable_frames=3
):
    """
    image_source:
    - cv2.VideoCapture
    """
    judge = StableClassJudge(
        stable_frames=stable_frames,
        ignore_class=2  # 有遮挡
    )

    cap = image_source
    if not hasattr(cap, "read"):
        raise TypeError("image_source 必须是 cv2.VideoCapture")

    try:
        while True:
            ret, frame = cap.read()

            # Check the read result BEFORE using the frame: on failure frame
            # is None and cv2.flip would raise instead of exiting cleanly.
            if not ret:
                print("读取帧失败,退出")
                break

            # 上下左右翻转
            frame = cv2.flip(frame, -1)

            result = classify_single_image(frame, model_path, roi_file)

            class_id = result["class_id"]
            class_name = result["class"]
            score = result["score"]

            print(f"[FRAME] {class_name} conf={score}")

            stable = judge.update(class_id)

            if stable is not None:
                print(f"\\n稳定输出: {CLASS_NAMES[stable]} \\n")

            # NOTE: waitKey only receives key presses while an OpenCV window
            # (cv2.imshow) is open; without one, 'q' has no effect.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture even if classification raises.
        cap.release()
        cv2.destroyAllWindows()
'''
# ---------------------------
|