This commit is contained in:
2026-04-07 09:51:38 +08:00
parent ecba4d726a
commit 00dcd6b6cc
36 changed files with 2857 additions and 505 deletions

Binary file not shown.

View File

@ -0,0 +1,131 @@
import os
import cv2
import numpy as np
from collections import deque
class StableClassJudge:
"""
连续三帧稳定判决器:
- class0 / class1 连续 3 帧 -> 输出
- class2 -> 清空计数,重新统计
"""
def __init__(self, stable_frames=3, ignore_class=2):
self.stable_frames = stable_frames
self.ignore_class = ignore_class
self.buffer = deque(maxlen=stable_frames)
def reset(self):
self.buffer.clear()
def update(self, class_id):
if class_id == self.ignore_class:
self.reset()
return None
self.buffer.append(class_id)
if len(self.buffer) < self.stable_frames:
return None
if len(set(self.buffer)) == 1:
stable_class = self.buffer[0]
self.reset()
return stable_class
return None
# ---------------------------
# 三分类映射
# ---------------------------
CLASS_NAMES = {
0: "插好",
1: "未插好",
2: "有遮挡"
}
# ---------------------------
# RKNN 全局实例(只加载一次)
# ---------------------------
_global_rknn = None
def init_rknn_model(model_path):
from rknnlite.api import RKNNLite
global _global_rknn
if _global_rknn is not None:
return _global_rknn
rknn = RKNNLite(verbose=False)
ret = rknn.load_rknn(model_path)
if ret != 0:
raise RuntimeError(f"Load RKNN failed: {ret}")
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
if ret != 0:
raise RuntimeError(f"Init runtime failed: {ret}")
_global_rknn = rknn
print(f"[INFO] RKNN 模型加载成功:{model_path}")
return rknn
# ---------------------------
# 预处理(输入 uint8RKNN 内部转 float32
# ---------------------------
def resize_stretch(image, size=640):
return cv2.resize(image, (size, size))
def preprocess_image_for_rknn(img, size=640):
# 输入必须是 uint8 [0,255],即使模型是 float32
img_resized = resize_stretch(img, size=size)
img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
input_tensor = np.expand_dims(img_rgb, axis=0).astype(np.uint8) # NHWC, uint8
return np.ascontiguousarray(input_tensor)
# ---------------------------
# 单次 RKNN 推理三分类float32 模型)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
rknn = init_rknn_model(model_path)
outs = rknn.inference([input_tensor])
# 直接得到 logits
probs = outs[0].flatten().astype(np.float32) # shape: (3,)
class_id = int(np.argmax(probs))
return class_id, probs
# ---------------------------
# 单张图片推理
# ---------------------------
def classify_single_image(frame, model_path, size=640):
if frame is None:
raise FileNotFoundError("输入帧为空")
input_tensor = preprocess_image_for_rknn(frame, size=size)
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
return {
"class_id": class_id,
"class": class_name,
"score": round(float(probs[class_id]), 4),
"raw": probs.tolist()
}
# ---------------------------
# 示例调用
# ---------------------------
if __name__ == "__main__":
model_path = "charge0324.rknn"
image_path = "class2.jpg"
frame = cv2.imread(image_path)
if frame is None:
raise FileNotFoundError(f"无法读取图片:{image_path}")
result = classify_single_image(frame, model_path)
print("[RESULT]", result)

View File

@ -0,0 +1,88 @@
import os
import cv2
#from rknnlite.api import RKNNLite
import time
# classify_single_image, StableClassJudge, CLASS_NAMES 已在 muju_cls_rknn 中定义
from .charge_cls_rknn import classify_single_image, StableClassJudge, CLASS_NAMES
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
def run_stable_charge_loop():
"""
image_source: cv2.VideoCapture 对象
"""
_ret=None
# 使用相对于当前文件的绝对路径
model_path = os.path.join(current_dir, "charge0324.rknn")
# roi_file = os.path.join(current_dir, "roi_coordinates", "muju_roi.txt")
RTSP_URL = "rtsp://admin:XJ123456@192.168.250.60:554/streaming/channels/101"
stable_frames=5
print(f"正在连接 RTSP 流: {RTSP_URL}")
cap =None
try:
cap = cv2.VideoCapture(RTSP_URL)
# 降低 RTSP 延迟(部分摄像头支持)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
if not cap.isOpened():
print("无法打开 RTSP 流,请检查网络、账号密码或 URL")
return None
print("RTSP 流连接成功,开始推理...")
judge = StableClassJudge(
stable_frames=stable_frames,
ignore_class=2 # 忽略“有遮挡”类别参与稳定判断
)
if not hasattr(cap, "read"):
raise TypeError("image_source 必须是 cv2.VideoCapture 实例")
_max_count=10
while True:
_max_count=_max_count-1
ret, frame = cap.read()
if not ret:
print("无法读取视频帧(可能是流断开或结束)")
continue
# 上下左右翻转
# frame = cv2.flip(frame, -1)
# ---------------------------
# 单帧推理
# ---------------------------
result = classify_single_image(frame, model_path)
class_id = result["class_id"]
class_name = result["class"]
score = result["score"]
print(f"[FRAME] {class_name} | conf={score:.3f}")
if score>0.8:
# ---------------------------
# 稳定判断
# ---------------------------
stable_class_id = judge.update(class_id)
if stable_class_id is not None:
_ret=CLASS_NAMES[stable_class_id]
if _ret is None:
print("-------当前振捣棒检测为空,继续等待稳定------")
continue
if _ret=="插好":
break
print(f"-------当前振捣棒检测为:{_ret},继续等待稳定------")
else:
print("-------当前振捣棒检测为空,继续等待稳定------")
time.sleep(0.1)
finally:
if cap is not None:
cap.release()
return _ret

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB