更新charge振捣判断

This commit is contained in:
琉璃月光
2026-03-10 16:51:57 +08:00
parent 235101b4d8
commit 5d79686ba0
15 changed files with 268873 additions and 0 deletions

25
1.txt Normal file
View File

@ -0,0 +1,25 @@
(zjsh) teamhd@teamhd:~/Downloads/rkmpp/build/linux/aarch64$ sudo apt install gcc-aarch64-linux-gnu
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Note, selecting 'gcc' instead of 'gcc-aarch64-linux-gnu'
Some packages could not be installed. This may mean that you have
requested an impossible situation or if you are using the unstable
distribution that some required packages have not yet been created
or been moved out of Incoming.
The following information may help to resolve the situation:
The following packages have unmet dependencies:
gcc-11 : Depends: cpp-11 (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
libasan6 : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
libcc1-0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
libgcc-11-dev : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
libhwasan0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
libitm1 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
liblsan0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
libtsan0 : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
libubsan1 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
E: Unable to correct problems, you have held broken packages.
(zjsh) teamhd@teamhd:~/Downloads/rkmpp/build/linux/aarch64$

BIN
charge_3cls/charge_cls.rknn Normal file

Binary file not shown.

View File

@ -0,0 +1,198 @@
import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
from collections import deque
class StableClassJudge:
    """Debounce per-frame class ids into a single stable decision.

    A class id is reported once it has been seen on ``stable_frames``
    consecutive frames. Seeing ``ignore_class`` discards the history so
    that counting starts over.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        """Create a judge.

        Args:
            stable_frames: Number of identical consecutive frames required.
            ignore_class: Class id that clears the history when observed.

        Raises:
            ValueError: If ``stable_frames`` is smaller than 1.
        """
        # A zero-length window can never fill up meaningfully, so the judge
        # would silently never report anything; reject it up front.
        if stable_frames < 1:
            raise ValueError(f"stable_frames must be >= 1, got {stable_frames}")
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget all buffered frame results."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one per-frame result.

        Returns:
            ``None`` while no stable decision exists, otherwise the class id
            shared by the last ``stable_frames`` frames.
        """
        # The ignored class restarts the count from scratch.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Window not full yet.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every buffered frame agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            self.reset()  # report once, then start counting again
            return stable_class

        return None
# ---------------------------
# Three-class label mapping (class id -> display name)
# ---------------------------
CLASS_NAMES = {
    0: "插好",  # plugged in properly
    1: "未插好",  # not plugged in properly
    2: "有遮挡"  # view is occluded
}
# ---------------------------
# Process-wide RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None


def init_rknn_model(model_path):
    """Return the shared RKNNLite runtime, loading it on first use.

    Only the first call loads a model. Later calls return the cached
    instance and ignore ``model_path``.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status.
    """
    global _global_rknn
    if _global_rknn is not None:
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Do not leave a half-initialised runtime behind on failure.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功:{model_path}")
    return rknn
# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Scale ``image`` to fit a ``new_size`` square and pad the remainder.

    The aspect ratio is kept and the resized image is centred on a
    ``(new_size, new_size, 3)`` uint8 canvas filled with ``color``.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # int() truncates, so a very elongated input could produce a 0-pixel
    # side, which cv2.resize rejects; keep at least one pixel per side.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))
    new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img
def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without keeping the aspect ratio."""
    target_wh = (size, size)
    return cv2.resize(image, target_wh)
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Turn a BGR image into a batched float32 tensor for inference.

    Args:
        img: Input image as read by OpenCV.
        size: Side length of the square model input.
        resize_mode: ``"letterbox"`` pads to keep the aspect ratio; any
            other value stretches the image.
        to_rgb: Convert BGR to RGB when true.
        normalize: Divide pixel values by 255 when true.
        layout: ``"NHWC"`` keeps channels last; any other value moves the
            channel axis to the front.

    Returns:
        A C-contiguous float32 array with a leading batch axis of 1.
    """
    if resize_mode == "letterbox":
        resized = letterbox(img, new_size=size)
    else:
        resized = resize_stretch(img, size=size)

    if to_rgb:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    tensor = resized.astype(np.float32)
    if normalize:
        tensor /= 255.0

    if layout != "NHWC":
        # HWC -> CHW for channel-first models
        tensor = np.transpose(tensor, (2, 0, 1))

    batched = np.expand_dims(tensor, axis=0)
    return np.ascontiguousarray(batched)
# ---------------------------
# Single RKNN inference (three classes)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference and return ``(class_id, probabilities)``.

    The first model output is flattened, treated as logits and passed
    through a softmax; ``class_id`` is the index of the largest value.
    """
    runtime = init_rknn_model(model_path)
    outputs = runtime.inference([input_tensor])
    logits = outputs[0].reshape(-1).astype(np.float32)

    # Softmax, shifted by the maximum before exponentiation.
    shifted = np.exp(logits - logits.max())
    probs = shifted / shifted.sum()

    return int(probs.argmax()), probs
# ---------------------------
# Single-image inference (three classes) - ROI logic removed
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Classify one full frame.

    Returns:
        A dict with ``class_id``, ``class`` (display name), ``score``
        (probability of the winning class rounded to 4 decimals) and
        ``raw`` (all probabilities as a list).

    Raises:
        FileNotFoundError: If ``frame`` is ``None``.
    """
    if frame is None:
        raise FileNotFoundError("❌ 输入帧为空")

    # The whole image is used; no cropping takes place here.
    tensor = preprocess_image_for_rknn(
        frame,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout,
    )
    class_id, probs = rknn_classify_preprocessed(tensor, model_path)

    result = {"class_id": class_id}
    result["class"] = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
    result["score"] = round(float(probs[class_id]), 4)
    result["raw"] = probs.tolist()
    return result
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    # No ROI file is needed any more; the whole image is classified.
    image_path = "class2.png"
    model_path = "charge_cls.rknn"

    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"❌ 无法读取图片:{image_path}")

    result = classify_single_image(frame, model_path)
    print("[RESULT]", result)

BIN
charge_3cls/class1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

BIN
charge_3cls/class2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

170
image_01.py Normal file
View File

@ -0,0 +1,170 @@
import cv2
import time
import os
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
# ================== Configuration ==================
# Stream address. Set the RTSP_URL environment variable to override it so
# that camera credentials do not have to live in the source file.
RTSP_URL = os.environ.get(
    "RTSP_URL",
    "rtsp://admin:XJ123456@192.168.10.50:554/streaming/channels/101",
)
SAVE_INTERVAL = 10  # process one frame out of every 10
DISPLAY_STREAM = False  # show the stream in a window
# --- Gray-frame filter ---
GRAY_LOWER = 70
GRAY_UPPER = 230
GRAY_RATIO_THRESHOLD = 0.7  # a frame counts as gray above this share of gray pixels
# --- SSIM de-duplication ---
SSIM_THRESHOLD = 0.9  # a frame counts as a duplicate above this similarity
# --- Output directory ---
OUTPUT_DIR = "camera02_save"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ================== Helpers ==================
def is_large_gray(image):
    """Tell whether most of the image is gray.

    A pixel counts as gray when each of its first three channels lies in
    ``[GRAY_LOWER, GRAY_UPPER]``. The image is reported as gray when the
    share of such pixels exceeds ``GRAY_RATIO_THRESHOLD``.

    Args:
        image: PIL image (RGB) or anything ``np.array`` converts likewise.
    """
    pixels = np.array(image)
    per_channel = [
        (pixels[:, :, ch] >= GRAY_LOWER) & (pixels[:, :, ch] <= GRAY_UPPER)
        for ch in range(3)
    ]
    gray_mask = np.logical_and.reduce(per_channel)
    gray_share = np.mean(gray_mask)
    return gray_share > GRAY_RATIO_THRESHOLD
def open_camera(url):
    """Open the stream at ``url``.

    Returns:
        The opened ``cv2.VideoCapture``, or ``None`` when it cannot be opened.
    """
    cap = cv2.VideoCapture(url)
    if not cap.isOpened():
        # Free whatever the failed attempt allocated before giving up.
        cap.release()
        return None
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
    # NOTE(review): OpenCV documents both timeout properties as open-only,
    # so setting them after the capture is open probably has no effect -
    # confirm, and pass them at open time if the timeouts are required.
    cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 5000)
    cap.set(cv2.CAP_PROP_READ_TIMEOUT_MSEC, 5000)
    return cap
def get_valid_frame(cap):
    """Read one frame defensively.

    Returns:
        ``(frame, True)`` on success, ``(None, False)`` when the capture is
        missing or closed, or the read yields no usable image.
    """
    failure = (None, False)
    if cap is None:
        return failure
    if not cap.isOpened():
        return failure

    ok, image = cap.read()
    if ok and image is not None and image.size != 0:
        return image, True
    return failure
# ================== Main program ==================
# Capture loop: sample every SAVE_INTERVAL-th frame, drop gray frames and
# near-duplicates (SSIM), and save the remaining frames as PNG files.
print(f"✅ 正在连接摄像头: {RTSP_URL} ...")
cap = open_camera(RTSP_URL)
if cap is None:
    print("❌ 连接失败!请检查 IP、账号密码或网络。")
    # exit() is injected by the site module and is not always available.
    raise SystemExit(1)

print("📡 摄像头已连接,开始采集...")
print(f" - 保存目录: {os.path.abspath(OUTPUT_DIR)}")
print(f" - 灰度阈值: [{GRAY_LOWER}, {GRAY_UPPER}], 占比 > {GRAY_RATIO_THRESHOLD}")
print(f" - SSIM 去重阈值: > {SSIM_THRESHOLD}")
print(" - 按 Ctrl+C 停止\n")

frame_count = 0
saved_count = 0
last_gray_frame = None  # grayscale reference frame for the SSIM comparison

try:
    while True:
        frame, ret = get_valid_frame(cap)
        if not ret:
            print("⚠️ 读取帧失败,尝试重连...")
            time.sleep(2)
            cap.release()
            cap = open_camera(RTSP_URL)
            if cap is None:
                print("❌ 重连失败,退出程序。")
                break
            last_gray_frame = None  # drop the reference after a reconnect
            continue

        frame_count += 1

        # Interval sampling: frames in between are skipped (optionally shown).
        if frame_count % SAVE_INTERVAL != 0:
            if DISPLAY_STREAM:
                cv2.imshow("Camera Stream", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            continue

        # --- Step 1: gray-frame detection ---
        try:
            # PIL expects RGB channel order.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(img_rgb)
            if is_large_gray(pil_img):
                print(f"⏭️ 跳过:检测到灰图 (帧 {frame_count})")
                continue
        except Exception as e:
            print(f"⚠️ 灰图检测异常: {e}")
            continue

        # --- Step 2: SSIM de-duplication ---
        try:
            gray_curr = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if last_gray_frame is not None:
                # Resize the reference if the stream resolution changed,
                # because SSIM needs equal shapes.
                if last_gray_frame.shape != gray_curr.shape:
                    last_gray_frame = cv2.resize(
                        last_gray_frame,
                        (gray_curr.shape[1], gray_curr.shape[0]),
                    )
                sim_score = ssim(last_gray_frame, gray_curr)
                if sim_score > SSIM_THRESHOLD:
                    print(f"⏭️ 跳过:画面重复 (SSIM={sim_score:.3f}, 帧 {frame_count})")
                    continue
            # Only frames that pass the check become the new reference.
            last_gray_frame = gray_curr.copy()
        except Exception as e:
            print(f"⚠️ SSIM 计算异常: {e}")
            # Skip the frame instead of crashing on a comparison error.
            continue

        # --- Step 3: save the frame ---
        # Take seconds and milliseconds from one clock reading; two separate
        # readings can straddle a second boundary and mislabel the file.
        now = time.time()
        ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
        ms = int((now % 1) * 1000)
        filename = f"img_{ts}_{ms:03d}.png"
        save_path = os.path.join(OUTPUT_DIR, filename)

        # Rotate or flip here if needed (for example a 180 degree flip):
        # frame = cv2.flip(frame, -1)

        # imwrite reports failure through its return value.
        if not cv2.imwrite(save_path, frame):
            print(f"⚠️ 保存失败: {save_path}")
            continue
        saved_count += 1
        print(f"✅ [{saved_count}] 已保存有效图片: {filename}")

        if DISPLAY_STREAM:
            cv2.putText(frame, f"Saved: {filename}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
            cv2.imshow("Camera Stream", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
except KeyboardInterrupt:
    print("\n🛑 用户手动停止")
finally:
    if cap is not None:
        cap.release()
    cv2.destroyAllWindows()
    print(f"🔚 程序结束。共保存 {saved_count} 张图片。")

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@ -75,10 +75,12 @@ def largest_intersect_cc(mask_bin, bbox):
# RANSAC 直线拟合(核心新增) # RANSAC 直线拟合(核心新增)
# --------------------------- # ---------------------------
def fit_line_ransac(pts, max_dist=2.5, min_inliers_ratio=0.6, iters=100): def fit_line_ransac(pts, max_dist=2.5, min_inliers_ratio=0.6, iters=100):
""" """
拟合 x = m*y + b 拟合 x = m*y + b
pts: Nx2 -> [x,y] pts: Nx2 -> [x,y]
""" """
np.random.seed(42)
if len(pts) < 10: if len(pts) < 10:
return None return None

117
zdb_cls/main.py Normal file
View File

@ -0,0 +1,117 @@
import os
import cv2
from rknnlite.api import RKNNLite
# classify_single_image, StableClassJudge and CLASS_NAMES are defined in zdb_cls_rknn
from zdb_cls_rknn import classify_single_image, StableClassJudge, CLASS_NAMES
def run_stable_classification_loop(
    model_path,
    roi_file,
    image_source,
    stable_frames=3,
    display_scale=0.5,  # display scale, 0.5 = half the original size
    show_window=False   # whether to open a preview window
):
    """Classify frames from a capture and print debounced decisions.

    Each frame is classified and fed to a ``StableClassJudge``; a line is
    printed whenever the judge reports a stable class. The loop ends when
    a frame cannot be read or, with the window shown, when 'q' is pressed.
    The capture is released on exit, also when an error is raised.

    Args:
        model_path: Path of the RKNN model.
        roi_file: Path of the ROI text file.
        image_source: A ``cv2.VideoCapture``-like object with ``read``.
        stable_frames: Consecutive identical frames needed for a decision.
        display_scale: Scale factor applied to the preview.
        show_window: Show a preview window when true.

    Raises:
        TypeError: If ``image_source`` has no ``read`` attribute.
    """
    window_title = "RTSP Stream - Press 'q' to quit"

    judge = StableClassJudge(
        stable_frames=stable_frames,
        ignore_class=2  # the "occluded" class never counts towards stability
    )

    cap = image_source
    if not hasattr(cap, "read"):
        raise TypeError("image_source 必须是 cv2.VideoCapture 实例")

    # Optional resizable preview window.
    if show_window:
        cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)

    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                print("无法读取视频帧(可能是流断开或结束)")
                break

            # Single-frame inference.
            result = classify_single_image(frame, model_path, roi_file)
            class_id = result["class_id"]
            class_name = result["class"]
            score = result["score"]
            print(f"[FRAME] {class_name} | conf={score:.3f}")

            # Debounce.
            stable_class_id = judge.update(class_id)
            if stable_class_id is not None:
                print(f"\n稳定输出: {CLASS_NAMES[stable_class_id]}\n")

            # Scaled-down preview.
            if show_window:
                h, w = frame.shape[:2]
                display_frame = cv2.resize(
                    frame,
                    (int(w * display_scale), int(h * display_scale)),
                    interpolation=cv2.INTER_AREA
                )
                cv2.imshow(window_title, display_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the stream even if inference raised.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # ---------------------------
    # Configuration
    # ---------------------------
    MODEL_PATH = "zdb_cls.rknn"
    ROI_FILE = "./roi_coordinates/zdb_roi.txt"
    # Set the RTSP_URL environment variable to override the address so that
    # camera credentials do not have to live in the source file.
    RTSP_URL = os.environ.get(
        "RTSP_URL",
        "rtsp://admin:XJ123456@192.168.250.60:554/streaming/channels/101",
    )
    STABLE_FRAMES = 3
    DISPLAY_SCALE = 0.5  # preview scale factor
    SHOW_WINDOW = False  # keep False for deployment; enable while testing

    # ---------------------------
    # Open the RTSP stream
    # ---------------------------
    print(f"正在连接 RTSP 流: {RTSP_URL}")
    cap = cv2.VideoCapture(RTSP_URL)
    # Smaller buffer to reduce latency (only some cameras honour it).
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    if not cap.isOpened():
        print("无法打开 RTSP 流,请检查网络、账号密码或 URL")
        cap.release()
        # exit() is injected by the site module and is not always available.
        raise SystemExit(1)
    print("RTSP 流连接成功,开始推理...")

    # ---------------------------
    # Run the debounced classification loop
    # ---------------------------
    run_stable_classification_loop(
        model_path=MODEL_PATH,
        roi_file=ROI_FILE,
        image_source=cap,
        stable_frames=STABLE_FRAMES,
        display_scale=DISPLAY_SCALE,
        show_window=SHOW_WINDOW
    )

View File

@ -0,0 +1 @@
2,880,385,200

BIN
zdb_cls/test.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 MiB

275
zdb_cls/test_imagesave.py Normal file
View File

@ -0,0 +1,275 @@
import os
import cv2
import time
import numpy as np
from datetime import datetime
from collections import deque
from rknnlite.api import RKNNLite
# =====================================================
# Stable-decision judge
# =====================================================
class StableClassJudge:
    """Report a class once it repeats on N consecutive frames.

    Observing ``ignore_class`` empties the history; a reported decision
    empties it as well.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        self.ignore_class = ignore_class
        self.stable_frames = stable_frames
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Drop all buffered results."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one frame result; return the stable class id or ``None``."""
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Not enough frames yet, or the buffered frames disagree.
        if len(self.buffer) < self.stable_frames or len(set(self.buffer)) != 1:
            return None

        decided = self.buffer[0]
        self.reset()
        return decided
# =====================================================
# Class definitions (class id -> display name)
# =====================================================
CLASS_NAMES = {
    0: "未插入振捣棒",  # vibrator rod not inserted
    1: "安全插入振捣棒",  # vibrator rod safely inserted
    2: "有遮挡"  # view is occluded
}
# =====================================================
# Process-wide RKNN instance
# =====================================================
_global_rknn = None


def init_rknn_model(model_path):
    """Return the shared RKNNLite runtime, loading it on first use.

    Only the first call loads a model. Later calls return the cached
    instance and ignore ``model_path``.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status.
    """
    global _global_rknn
    if _global_rknn is not None:
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Do not leave a half-initialised runtime behind on failure.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# =====================================================
# Image preprocessing
# =====================================================
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Fit ``image`` into a ``new_size`` square, padding with ``color``.

    The aspect ratio is kept and the resized image is centred on the canvas.
    """
    src_h, src_w = image.shape[:2]
    ratio = min(new_size / src_h, new_size / src_w)
    fit_h = int(src_h * ratio)
    fit_w = int(src_w * ratio)
    fitted = cv2.resize(image, (fit_w, fit_h))

    padded = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    pad_top = (new_size - fit_h) // 2
    pad_left = (new_size - fit_w) // 2
    padded[pad_top:pad_top + fit_h, pad_left:pad_left + fit_w] = fitted
    return padded
def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without keeping the aspect ratio."""
    target_wh = (size, size)
    return cv2.resize(image, target_wh)
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Turn a BGR image into a batched float32 tensor for inference.

    ``resize_mode="letterbox"`` pads to keep the aspect ratio, any other
    value stretches. ``layout="NHWC"`` keeps channels last, any other
    value moves the channel axis to the front. The result is C-contiguous
    and has a leading batch axis of 1.
    """
    resizer = letterbox if resize_mode == "letterbox" else resize_stretch
    resized = resizer(img, size)

    if to_rgb:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    tensor = resized.astype(np.float32)
    if normalize:
        tensor /= 255.0

    if layout != "NHWC":
        # HWC -> CHW for channel-first models
        tensor = np.transpose(tensor, (2, 0, 1))

    return np.ascontiguousarray(np.expand_dims(tensor, axis=0))
# =====================================================
# Single RKNN inference
# =====================================================
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference and return ``(class_id, scores)``.

    The first model output is flattened and used as-is (no softmax is
    applied here); ``class_id`` is the index of its largest value.
    """
    runtime = init_rknn_model(model_path)
    outputs = runtime.inference([input_tensor])
    scores = outputs[0].reshape(-1).astype(float)
    return int(scores.argmax()), scores
# =====================================================
# ROI handling
# =====================================================
def load_single_roi(txt_path):
    """Read the first non-empty ``x,y,w,h`` line of an ROI file.

    Returns:
        The ROI as a tuple ``(x, y, w, h)`` of ints.

    Raises:
        RuntimeError: If the file does not exist or has no non-empty line.
        ValueError: If the line is not four comma-separated integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")

    # utf-8-sig drops a leading BOM (written by some Windows editors) that
    # would otherwise make int() fail on the first field; it also removes
    # the dependency on the locale's default encoding.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            x, y, w, h = map(int, line.split(","))
            return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")
def crop_and_return_roi(img, roi):
    """Return the ``(x, y, w, h)`` region of ``img`` as a slice.

    Raises:
        RuntimeError: If the ROI has a non-positive size or does not lie
            completely inside the image.
    """
    x, y, w, h = roi
    H, W = img.shape[:2]
    # A zero or negative size would pass the bounds check below and turn
    # into an empty or wrong slice, so reject it explicitly.
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI 尺寸无效: {roi}")
    if x < 0 or y < 0 or x + w > W or y + h > H:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")
    return img[y:y + h, x:x + w]
# =====================================================
# Single-frame classification
# =====================================================
def classify_single_image(frame, model_path, roi_file):
    """Classify the ROI of one frame.

    The ROI is read from ``roi_file`` on every call, cropped from
    ``frame``, stretched to 640x640 and passed to the model.

    Returns:
        A dict with ``class_id``, ``class`` (display name), ``score``
        (score of the winning class rounded to 4 decimals) and ``raw``
        (all scores as a list).
    """
    roi = load_single_roi(roi_file)
    roi_img = crop_and_return_roi(frame, roi)

    input_tensor = preprocess_image_for_rknn(
        roi_img,
        size=640,
        resize_mode="stretch",
        to_rgb=True,
        normalize=False,
        layout="NHWC"
    )

    class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)

    # Fall back to a placeholder name instead of raising KeyError when the
    # model yields an index outside CLASS_NAMES.
    class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")

    return {
        "class_id": class_id,
        "class": class_name,
        "score": round(float(probs[class_id]), 4),
        "raw": probs.tolist()
    }
# =====================================================
# RTSP inference + saving classified frames
# =====================================================
def run_rtsp_classification_and_save(
    model_path,
    roi_file,
    rtsp_url,
    save_root="clsimg",
    stable_frames=3,
    save_mode="all"  # all / stable
):
    """Classify an RTSP stream and save frames into per-class folders.

    Frames are flipped on both axes before classification. With
    ``save_mode="all"`` every frame is stored under its own class; with
    ``"stable"`` a frame is stored only when the judge reports a stable
    class, under that class.

    Args:
        model_path: Path of the RKNN model.
        roi_file: Path of the ROI text file.
        rtsp_url: Stream address.
        save_root: Directory that receives the ``class<N>`` folders.
        stable_frames: Consecutive identical frames needed for a decision.
        save_mode: ``"all"`` or ``"stable"``.

    Raises:
        RuntimeError: If the stream cannot be opened.
    """
    max_failures = 50  # about 5 s of failed reads at the 0.1 s retry delay

    for cid in CLASS_NAMES:
        os.makedirs(os.path.join(save_root, f"class{cid}"), exist_ok=True)

    cap = cv2.VideoCapture(rtsp_url)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"无法打开 RTSP: {rtsp_url}")

    judge = StableClassJudge(stable_frames=stable_frames, ignore_class=2)
    print("[INFO] RTSP 推理开始")

    failures = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                print("[WARN] RTSP 读帧失败")
                failures += 1
                if failures >= max_failures:
                    # Retrying the same handle may never recover a dropped
                    # session, so open a fresh capture.
                    cap.release()
                    cap = cv2.VideoCapture(rtsp_url)
                    failures = 0
                time.sleep(0.1)
                continue
            failures = 0

            frame = cv2.flip(frame, -1)

            result = classify_single_image(frame, model_path, roi_file)
            class_id = result["class_id"]
            score = result["score"]
            print(f"[FRAME] {result['class']} conf={score}")

            stable = judge.update(class_id)

            save_flag = False
            save_class = class_id
            if save_mode == "all":
                save_flag = True
            elif save_mode == "stable" and stable is not None:
                save_flag = True
                save_class = stable

            if save_flag:
                ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filename = f"{ts}_conf{score:.2f}.jpg"
                save_dir = os.path.join(save_root, f"class{save_class}")
                # imwrite reports failure through its return value.
                if cv2.imwrite(os.path.join(save_dir, filename), frame):
                    print(f"[SAVE] class{save_class}/{filename}")
                else:
                    print(f"[WARN] 保存失败: class{save_class}/{filename}")

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Also reached on Ctrl+C or when inference raises.
        cap.release()
        cv2.destroyAllWindows()
# =====================================================
# Entry point
# =====================================================
if __name__ == "__main__":
    run_rtsp_classification_and_save(
        model_path="zdb_cls.rknn",
        roi_file="./roi_coordinates/zdb_roi.txt",
        rtsp_url="rtsp://admin:XJ123456@192.168.250.60:554/streaming/channels/101",
        save_root="clsimg",
        stable_frames=3,
        save_mode="all",  # use "stable" to store only debounced results
    )

BIN
zdb_cls/zdb_cls.rknn Normal file

Binary file not shown.

282
zdb_cls/zdb_cls_rknn.py Normal file
View File

@ -0,0 +1,282 @@
import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
from collections import deque
class StableClassJudge:
    """Debounce per-frame class ids into a single stable decision.

    A class id is reported once it has been seen on ``stable_frames``
    consecutive frames. Seeing ``ignore_class`` discards the history so
    that counting starts over.
    """

    def __init__(self, stable_frames=3, ignore_class=2):
        """Create a judge.

        Args:
            stable_frames: Number of identical consecutive frames required.
            ignore_class: Class id that clears the history when observed.

        Raises:
            ValueError: If ``stable_frames`` is smaller than 1.
        """
        # A zero-length window can never fill up meaningfully, so the judge
        # would silently never report anything; reject it up front.
        if stable_frames < 1:
            raise ValueError(f"stable_frames must be >= 1, got {stable_frames}")
        self.stable_frames = stable_frames
        self.ignore_class = ignore_class
        self.buffer = deque(maxlen=stable_frames)

    def reset(self):
        """Forget all buffered frame results."""
        self.buffer.clear()

    def update(self, class_id):
        """Feed one per-frame result.

        Returns:
            ``None`` while no stable decision exists, otherwise the class id
            shared by the last ``stable_frames`` frames.
        """
        # The ignored class restarts the count from scratch.
        if class_id == self.ignore_class:
            self.reset()
            return None

        self.buffer.append(class_id)

        # Window not full yet.
        if len(self.buffer) < self.stable_frames:
            return None

        # Every buffered frame agrees.
        if len(set(self.buffer)) == 1:
            stable_class = self.buffer[0]
            self.reset()  # report once, then start counting again
            return stable_class

        return None
# ---------------------------
# Three-class label mapping: mold cart 1 is the small one, mold cart 2 the large one
# ---------------------------
CLASS_NAMES = {
    0: "模具车1",  # mold cart 1
    1: "模具车2",  # mold cart 2
    2: "有遮挡"  # view is occluded
}
# ---------------------------
# Process-wide RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None


def init_rknn_model(model_path):
    """Return the shared RKNNLite runtime, loading it on first use.

    Only the first call loads a model. Later calls return the cached
    instance and ignore ``model_path``.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            returns a non-zero status.
    """
    global _global_rknn
    if _global_rknn is not None:
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Do not leave a half-initialised runtime behind on failure.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# ---------------------------
# Preprocessing
# ---------------------------
def letterbox(image, new_size=640, color=(114, 114, 114)):
    """Scale ``image`` to fit a ``new_size`` square and pad the remainder.

    The aspect ratio is kept and the resized image is centred on a
    ``(new_size, new_size, 3)`` uint8 canvas filled with ``color``.
    """
    h, w = image.shape[:2]
    scale = min(new_size / h, new_size / w)
    # int() truncates, so a very elongated input could produce a 0-pixel
    # side, which cv2.resize rejects; keep at least one pixel per side.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    resized = cv2.resize(image, (nw, nh))
    new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    new_img[top:top + nh, left:left + nw] = resized
    return new_img
def resize_stretch(image, size=640):
    """Resize ``image`` to ``size`` x ``size`` without keeping the aspect ratio."""
    target_wh = (size, size)
    return cv2.resize(image, target_wh)
def preprocess_image_for_rknn(
    img,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Turn a BGR image into a batched float32 tensor for inference.

    ``resize_mode="letterbox"`` pads to keep the aspect ratio, any other
    value stretches. ``layout="NHWC"`` keeps channels last, any other
    value moves the channel axis to the front. The result is C-contiguous
    and has a leading batch axis of 1.
    """
    if resize_mode == "letterbox":
        resized = letterbox(img, new_size=size)
    else:
        resized = resize_stretch(img, size=size)

    if to_rgb:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    tensor = resized.astype(np.float32)
    if normalize:
        tensor /= 255.0

    if layout != "NHWC":
        # HWC -> CHW for channel-first models
        tensor = np.transpose(tensor, (2, 0, 1))

    batched = np.expand_dims(tensor, axis=0)
    return np.ascontiguousarray(batched)
# ---------------------------
# Single RKNN inference (three classes)
# ---------------------------
def rknn_classify_preprocessed(input_tensor, model_path):
    """Run one inference and return ``(class_id, scores)``.

    The input is converted to a contiguous float32 array first. The first
    model output is flattened and used as-is (no softmax is applied
    here); ``class_id`` is the index of its largest value.
    """
    runtime = init_rknn_model(model_path)
    as_f32 = input_tensor.astype(np.float32)
    outputs = runtime.inference([np.ascontiguousarray(as_f32)])
    scores = outputs[0].reshape(-1).astype(float)
    return int(scores.argmax()), scores
# ---------------------------
# ROI
# ---------------------------
def load_single_roi(txt_path):
    """Read the first non-empty ``x,y,w,h`` line of an ROI file.

    Returns:
        The ROI as a tuple ``(x, y, w, h)`` of ints.

    Raises:
        RuntimeError: If the file does not exist or has no non-empty line.
        ValueError: If the line is not four comma-separated integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")

    # utf-8-sig drops a leading BOM (written by some Windows editors) that
    # would otherwise make int() fail on the first field; it also removes
    # the dependency on the locale's default encoding.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            x, y, w, h = map(int, s.split(','))
            return (x, y, w, h)

    raise RuntimeError("ROI 文件为空")
def crop_and_return_roi(img, roi):
    """Return the ``(x, y, w, h)`` region of ``img`` as a slice.

    Raises:
        RuntimeError: If the ROI has a non-positive size or does not lie
            completely inside the image.
    """
    x, y, w, h = roi
    h_img, w_img = img.shape[:2]
    # A zero or negative size would pass the bounds check below and turn
    # into an empty or wrong slice, so reject it explicitly.
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI 尺寸无效: {roi}")
    if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")
    return img[y:y + h, x:x + w]
# ---------------------------
# Single-image inference (three classes)
# ---------------------------
def classify_single_image(
    frame,
    model_path,
    roi_file,
    size=640,
    resize_mode="stretch",
    to_rgb=True,
    normalize=False,
    layout="NHWC"
):
    """Classify the ROI of one frame.

    The ROI is read from ``roi_file`` on every call and cropped from
    ``frame`` before preprocessing and inference.

    Returns:
        A dict with ``class_id``, ``class`` (display name), ``score``
        (score of the winning class rounded to 4 decimals) and ``raw``
        (all scores as a list).

    Raises:
        FileNotFoundError: If ``frame`` is ``None``.
    """
    if frame is None:
        raise FileNotFoundError("输入帧为空")

    cropped = crop_and_return_roi(frame, load_single_roi(roi_file))
    tensor = preprocess_image_for_rknn(
        cropped,
        size=size,
        resize_mode=resize_mode,
        to_rgb=to_rgb,
        normalize=normalize,
        layout=layout,
    )
    class_id, probs = rknn_classify_preprocessed(tensor, model_path)

    result = {"class_id": class_id}
    result["class"] = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
    result["score"] = round(float(probs[class_id]), 4)
    result["raw"] = probs.tolist()
    return result
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    image_path = "./test_image/test.png"
    roi_file = "./roi_coordinates/zdb_roi.txt"
    model_path = "zdb_cls.rknn"

    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"无法读取图片: {image_path}")

    result = classify_single_image(frame, model_path, roi_file)
    print("[RESULT]", result)
# ---------------------------
# Example decision loop (kept in a string literal, never executed)
'''
import cv2
from zdb_cls_rknn import classify_single_image, StableClassJudge, CLASS_NAMES

def run_stable_classification_loop(
    model_path,
    roi_file,
    image_source,
    stable_frames=3
):
    """
    image_source:
        - cv2.VideoCapture
    """
    judge = StableClassJudge(
        stable_frames=stable_frames,
        ignore_class=2  # occluded
    )

    cap = image_source
    if not hasattr(cap, "read"):
        raise TypeError("image_source 必须是 cv2.VideoCapture")

    while True:
        ret, frame = cap.read()
        # Check the read result first: flipping a missing frame raises.
        if not ret:
            print("读取帧失败,退出")
            break
        # Flip on both axes.
        frame = cv2.flip(frame, -1)

        result = classify_single_image(frame, model_path, roi_file)
        class_id = result["class_id"]
        class_name = result["class"]
        score = result["score"]
        print(f"[FRAME] {class_name} conf={score}")

        stable = judge.update(class_id)
        if stable is not None:
            print(f"\n稳定输出: {CLASS_NAMES[stable]} \n")

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
'''
# ---------------------------