更新charge振捣判断
This commit is contained in:
25
1.txt
Normal file
25
1.txt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
(zjsh) teamhd@teamhd:~/Downloads/rkmpp/build/linux/aarch64$ sudo apt install gcc-aarch64-linux-gnu
|
||||||
|
Reading package lists... Done
|
||||||
|
Building dependency tree... Done
|
||||||
|
Reading state information... Done
|
||||||
|
Note, selecting 'gcc' instead of 'gcc-aarch64-linux-gnu'
|
||||||
|
Some packages could not be installed. This may mean that you have
|
||||||
|
requested an impossible situation or if you are using the unstable
|
||||||
|
distribution that some required packages have not yet been created
|
||||||
|
or been moved out of Incoming.
|
||||||
|
The following information may help to resolve the situation:
|
||||||
|
|
||||||
|
The following packages have unmet dependencies:
|
||||||
|
gcc-11 : Depends: cpp-11 (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
|
||||||
|
Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libasan6 : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libcc1-0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libgcc-11-dev : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libhwasan0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libitm1 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
|
||||||
|
liblsan0 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libtsan0 : Depends: gcc-11-base (= 11.4.0-1ubuntu1~22.04.3) but 11.4.0-1ubuntu1~22.04 is to be installed
|
||||||
|
libubsan1 : Depends: gcc-12-base (= 12.3.0-1ubuntu1~22.04.3) but 12.3.0-1ubuntu1~22.04 is to be installed
|
||||||
|
E: Unable to correct problems, you have held broken packages.
|
||||||
|
(zjsh) teamhd@teamhd:~/Downloads/rkmpp/build/linux/aarch64$
|
||||||
|
|
||||||
BIN
charge_3cls/charge_cls.rknn
Normal file
BIN
charge_3cls/charge_cls.rknn
Normal file
Binary file not shown.
198
charge_3cls/charge_cls_rknn.py
Normal file
198
charge_3cls/charge_cls_rknn.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
class StableClassJudge:
|
||||||
|
"""
|
||||||
|
连续三帧稳定判决器:
|
||||||
|
- class0 / class1 连续 3 帧 -> 输出
|
||||||
|
- class2 -> 清空计数,重新统计
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stable_frames=3, ignore_class=2):
|
||||||
|
self.stable_frames = stable_frames
|
||||||
|
self.ignore_class = ignore_class
|
||||||
|
self.buffer = deque(maxlen=stable_frames)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.buffer.clear()
|
||||||
|
|
||||||
|
def update(self, class_id):
|
||||||
|
"""
|
||||||
|
输入单帧分类结果
|
||||||
|
返回:
|
||||||
|
- None:尚未稳定
|
||||||
|
- class_id:稳定输出结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 遇到 class2,直接清空重新计数
|
||||||
|
if class_id == self.ignore_class:
|
||||||
|
self.reset()
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.buffer.append(class_id)
|
||||||
|
|
||||||
|
# 缓冲未满
|
||||||
|
if len(self.buffer) < self.stable_frames:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 三帧完全一致
|
||||||
|
if len(set(self.buffer)) == 1:
|
||||||
|
stable_class = self.buffer[0]
|
||||||
|
self.reset() # 输出一次后重新计数(防止重复触发)
|
||||||
|
return stable_class
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 三分类映射
|
||||||
|
# ---------------------------
|
||||||
|
CLASS_NAMES = {
|
||||||
|
0: "插好",
|
||||||
|
1: "未插好",
|
||||||
|
2: "有遮挡"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# RKNN 全局实例(只加载一次)
|
||||||
|
# ---------------------------
|
||||||
|
_global_rknn = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_rknn_model(model_path):
|
||||||
|
global _global_rknn
|
||||||
|
if _global_rknn is not None:
|
||||||
|
return _global_rknn
|
||||||
|
|
||||||
|
rknn = RKNNLite(verbose=False)
|
||||||
|
ret = rknn.load_rknn(model_path)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Load RKNN failed: {ret}")
|
||||||
|
|
||||||
|
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Init runtime failed: {ret}")
|
||||||
|
|
||||||
|
_global_rknn = rknn
|
||||||
|
print(f"[INFO] RKNN 模型加载成功:{model_path}")
|
||||||
|
return rknn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 预处理
|
||||||
|
# ---------------------------
|
||||||
|
def letterbox(image, new_size=640, color=(114, 114, 114)):
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
scale = min(new_size / h, new_size / w)
|
||||||
|
nh, nw = int(h * scale), int(w * scale)
|
||||||
|
resized = cv2.resize(image, (nw, nh))
|
||||||
|
new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
|
||||||
|
top = (new_size - nh) // 2
|
||||||
|
left = (new_size - nw) // 2
|
||||||
|
new_img[top:top + nh, left:left + nw] = resized
|
||||||
|
return new_img
|
||||||
|
|
||||||
|
|
||||||
|
def resize_stretch(image, size=640):
|
||||||
|
return cv2.resize(image, (size, size))
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image_for_rknn(
|
||||||
|
img,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
):
|
||||||
|
if resize_mode == "letterbox":
|
||||||
|
img_box = letterbox(img, new_size=size)
|
||||||
|
else:
|
||||||
|
img_box = resize_stretch(img, size=size)
|
||||||
|
|
||||||
|
if to_rgb:
|
||||||
|
img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
img_f = img_box.astype(np.float32)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
img_f /= 255.0
|
||||||
|
|
||||||
|
if layout == "NHWC":
|
||||||
|
out = np.expand_dims(img_f, axis=0)
|
||||||
|
else:
|
||||||
|
out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)
|
||||||
|
|
||||||
|
return np.ascontiguousarray(out)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 单次 RKNN 推理(三分类)
|
||||||
|
# ---------------------------
|
||||||
|
def rknn_classify_preprocessed(input_tensor, model_path):
|
||||||
|
rknn = init_rknn_model(model_path)
|
||||||
|
|
||||||
|
outs = rknn.inference([input_tensor])
|
||||||
|
logits = outs[0].reshape(-1).astype(np.float32) # shape = (3,)
|
||||||
|
|
||||||
|
# softmax
|
||||||
|
exp = np.exp(logits - np.max(logits))
|
||||||
|
probs = exp / np.sum(exp)
|
||||||
|
|
||||||
|
class_id = int(np.argmax(probs))
|
||||||
|
return class_id, probs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 单张图片推理(三分类)- 已移除 ROI 逻辑
|
||||||
|
# ---------------------------
|
||||||
|
def classify_single_image(
|
||||||
|
frame,
|
||||||
|
model_path,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
):
|
||||||
|
if frame is None:
|
||||||
|
raise FileNotFoundError("❌ 输入帧为空")
|
||||||
|
|
||||||
|
# 直接使用整图,不再裁剪
|
||||||
|
input_tensor = preprocess_image_for_rknn(
|
||||||
|
frame,
|
||||||
|
size=size,
|
||||||
|
resize_mode=resize_mode,
|
||||||
|
to_rgb=to_rgb,
|
||||||
|
normalize=normalize,
|
||||||
|
layout=layout
|
||||||
|
)
|
||||||
|
|
||||||
|
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
|
||||||
|
class_name = CLASS_NAMES.get(class_id, f"未知类别 ({class_id})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"class_id": class_id,
|
||||||
|
"class": class_name,
|
||||||
|
"score": round(float(probs[class_id]), 4),
|
||||||
|
"raw": probs.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 示例调用
|
||||||
|
# ---------------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_path = "charge_cls.rknn"
|
||||||
|
# roi_file 已移除
|
||||||
|
image_path = "class2.png"
|
||||||
|
|
||||||
|
frame = cv2.imread(image_path)
|
||||||
|
if frame is None:
|
||||||
|
raise FileNotFoundError(f"❌ 无法读取图片:{image_path}")
|
||||||
|
|
||||||
|
# 调用
|
||||||
|
result = classify_single_image(frame, model_path)
|
||||||
|
print("[RESULT]", result)
|
||||||
BIN
charge_3cls/class1.png
Normal file
BIN
charge_3cls/class1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 MiB |
BIN
charge_3cls/class2.png
Normal file
BIN
charge_3cls/class2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 MiB |
170
image_01.py
Normal file
170
image_01.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
|
||||||
|
# ================== 配置区域 ==================
|
||||||
|
RTSP_URL = "rtsp://admin:XJ123456@192.168.10.50:554/streaming/channels/101"
|
||||||
|
|
||||||
|
SAVE_INTERVAL = 10 # 每 20 帧尝试一次处理
|
||||||
|
DISPLAY_STREAM = False # 是否显示画面
|
||||||
|
|
||||||
|
# --- 灰图过滤配置 ---
|
||||||
|
GRAY_LOWER = 70
|
||||||
|
GRAY_UPPER = 230
|
||||||
|
GRAY_RATIO_THRESHOLD = 0.7 # 灰色像素占比超过此值视为灰图
|
||||||
|
|
||||||
|
# --- SSIM 去重配置 ---
|
||||||
|
SSIM_THRESHOLD = 0.9 # 相似度超过此值视为重复图片
|
||||||
|
|
||||||
|
# --- 保存目录 ---
|
||||||
|
OUTPUT_DIR = "camera02_save"
|
||||||
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# ================== 辅助函数 ==================
|
||||||
|
|
||||||
|
def is_large_gray(image):
|
||||||
|
"""
|
||||||
|
检测图像是否为大面积灰图
|
||||||
|
image: PIL Image 对象 (RGB)
|
||||||
|
"""
|
||||||
|
arr = np.array(image)
|
||||||
|
# 检查 R, G, B 三个通道是否都在灰色范围内
|
||||||
|
gray_mask = (
|
||||||
|
(arr[:, :, 0] >= GRAY_LOWER) & (arr[:, :, 0] <= GRAY_UPPER) &
|
||||||
|
(arr[:, :, 1] >= GRAY_LOWER) & (arr[:, :, 1] <= GRAY_UPPER) &
|
||||||
|
(arr[:, :, 2] >= GRAY_LOWER) & (arr[:, :, 2] <= GRAY_UPPER)
|
||||||
|
)
|
||||||
|
# 计算灰色像素占比
|
||||||
|
return np.mean(gray_mask) > GRAY_RATIO_THRESHOLD
|
||||||
|
|
||||||
|
def open_camera(url):
|
||||||
|
cap = cv2.VideoCapture(url)
|
||||||
|
if not cap.isOpened():
|
||||||
|
return None
|
||||||
|
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||||
|
cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 5000)
|
||||||
|
cap.set(cv2.CAP_PROP_READ_TIMEOUT_MSEC, 5000)
|
||||||
|
return cap
|
||||||
|
|
||||||
|
def get_valid_frame(cap):
|
||||||
|
"""安全读取一帧"""
|
||||||
|
if cap is None or not cap.isOpened():
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret or frame is None or frame.size == 0:
|
||||||
|
return None, False
|
||||||
|
return frame, True
|
||||||
|
|
||||||
|
# ================== 主程序 ==================
|
||||||
|
|
||||||
|
print(f"✅ 正在连接摄像头: {RTSP_URL} ...")
|
||||||
|
cap = open_camera(RTSP_URL)
|
||||||
|
|
||||||
|
if cap is None:
|
||||||
|
print(f"❌ 连接失败!请检查 IP、账号密码或网络。")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("📡 摄像头已连接,开始采集...")
|
||||||
|
print(f" - 保存目录: {os.path.abspath(OUTPUT_DIR)}")
|
||||||
|
print(f" - 灰度阈值: [{GRAY_LOWER}, {GRAY_UPPER}], 占比 > {GRAY_RATIO_THRESHOLD}")
|
||||||
|
print(f" - SSIM 去重阈值: > {SSIM_THRESHOLD}")
|
||||||
|
print(" - 按 Ctrl+C 停止\n")
|
||||||
|
|
||||||
|
frame_count = 0
|
||||||
|
saved_count = 0
|
||||||
|
last_gray_frame = None # 用于存储上一帧的灰度图以计算 SSIM
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
frame, ret = get_valid_frame(cap)
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
print("⚠️ 读取帧失败,尝试重连...")
|
||||||
|
time.sleep(2)
|
||||||
|
cap.release()
|
||||||
|
cap = open_camera(RTSP_URL)
|
||||||
|
if cap is None:
|
||||||
|
print("❌ 重连失败,退出程序。")
|
||||||
|
break
|
||||||
|
last_gray_frame = None # 重连后重置对比帧
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
|
# 间隔采样:不是指定间隔的帧直接跳过(可选显示)
|
||||||
|
if frame_count % SAVE_INTERVAL != 0:
|
||||||
|
if DISPLAY_STREAM:
|
||||||
|
cv2.imshow("Camera Stream", frame)
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 步骤 1: 灰图检测 ---
|
||||||
|
try:
|
||||||
|
# 转 RGB 供 PIL 使用
|
||||||
|
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
pil_img = Image.fromarray(img_rgb)
|
||||||
|
|
||||||
|
if is_large_gray(pil_img):
|
||||||
|
print(f"⏭️ 跳过:检测到灰图 (帧 {frame_count})")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 灰图检测异常: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 步骤 2: SSIM 相似性去重 ---
|
||||||
|
try:
|
||||||
|
gray_curr = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
if last_gray_frame is not None:
|
||||||
|
# 计算结构相似性
|
||||||
|
# resize 确保尺寸一致(防止分辨率动态变化导致报错)
|
||||||
|
if last_gray_frame.shape != gray_curr.shape:
|
||||||
|
last_gray_frame = cv2.resize(last_gray_frame, (gray_curr.shape[1], gray_curr.shape[0]))
|
||||||
|
|
||||||
|
sim_score = ssim(last_gray_frame, gray_curr)
|
||||||
|
|
||||||
|
if sim_score > SSIM_THRESHOLD:
|
||||||
|
print(f"⏭️ 跳过:画面重复 (SSIM={sim_score:.3f}, 帧 {frame_count})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 更新参考帧
|
||||||
|
last_gray_frame = gray_curr.copy()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ SSIM 计算异常: {e}")
|
||||||
|
# 如果计算出错,可以选择跳过或强制保存,这里选择跳过以防崩溃
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 步骤 3: 保存图片 ---
|
||||||
|
ts = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
ms = int((time.time() % 1) * 1000)
|
||||||
|
filename = f"img_{ts}_{ms:03d}.png"
|
||||||
|
save_path = os.path.join(OUTPUT_DIR, filename)
|
||||||
|
|
||||||
|
# 如果需要旋转/翻转,在这里操作 (例如翻转 180 度)
|
||||||
|
# frame = cv2.flip(frame, -1)
|
||||||
|
|
||||||
|
cv2.imwrite(save_path, frame)
|
||||||
|
saved_count += 1
|
||||||
|
print(f"✅ [{saved_count}] 已保存有效图片: {filename}")
|
||||||
|
|
||||||
|
if DISPLAY_STREAM:
|
||||||
|
cv2.putText(frame, f"Saved: {filename}", (10, 30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||||
|
cv2.imshow("Camera Stream", frame)
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n🛑 用户手动停止")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if cap:
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
print(f"🔚 程序结束。共保存 {saved_count} 张图片。")
|
||||||
267803
miniforge/Miniforge3-Linux-aarch64.sh
Normal file
267803
miniforge/Miniforge3-Linux-aarch64.sh
Normal file
File diff suppressed because one or more lines are too long
BIN
rknn_save/yiliao_cls_60_0123.rknn
Normal file
BIN
rknn_save/yiliao_cls_60_0123.rknn
Normal file
Binary file not shown.
@ -75,10 +75,12 @@ def largest_intersect_cc(mask_bin, bbox):
|
|||||||
# RANSAC 直线拟合(核心新增)
|
# RANSAC 直线拟合(核心新增)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
def fit_line_ransac(pts, max_dist=2.5, min_inliers_ratio=0.6, iters=100):
|
def fit_line_ransac(pts, max_dist=2.5, min_inliers_ratio=0.6, iters=100):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
拟合 x = m*y + b
|
拟合 x = m*y + b
|
||||||
pts: Nx2 -> [x,y]
|
pts: Nx2 -> [x,y]
|
||||||
"""
|
"""
|
||||||
|
np.random.seed(42)
|
||||||
if len(pts) < 10:
|
if len(pts) < 10:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
117
zdb_cls/main.py
Normal file
117
zdb_cls/main.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
|
||||||
|
# classify_single_image, StableClassJudge, CLASS_NAMES 已在 muju_cls_rknn 中定义
|
||||||
|
from zdb_cls_rknn import classify_single_image, StableClassJudge, CLASS_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
def run_stable_classification_loop(
|
||||||
|
model_path,
|
||||||
|
roi_file,
|
||||||
|
image_source,
|
||||||
|
stable_frames=3,
|
||||||
|
display_scale=0.5, # 显示缩放比例(0.5 = 显示为原来 50%)
|
||||||
|
show_window=False # 是否显示窗口
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
image_source: cv2.VideoCapture 对象
|
||||||
|
"""
|
||||||
|
|
||||||
|
judge = StableClassJudge(
|
||||||
|
stable_frames=stable_frames,
|
||||||
|
ignore_class=2 # 忽略“有遮挡”类别参与稳定判断
|
||||||
|
)
|
||||||
|
|
||||||
|
cap = image_source
|
||||||
|
if not hasattr(cap, "read"):
|
||||||
|
raise TypeError("image_source 必须是 cv2.VideoCapture 实例")
|
||||||
|
|
||||||
|
# 可选:创建可缩放窗口
|
||||||
|
if show_window:
|
||||||
|
cv2.namedWindow("RTSP Stream - Press 'q' to quit", cv2.WINDOW_NORMAL)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
print("无法读取视频帧(可能是流断开或结束)")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 单帧推理
|
||||||
|
# ---------------------------
|
||||||
|
result = classify_single_image(frame, model_path, roi_file)
|
||||||
|
|
||||||
|
class_id = result["class_id"]
|
||||||
|
class_name = result["class"]
|
||||||
|
score = result["score"]
|
||||||
|
|
||||||
|
print(f"[FRAME] {class_name} | conf={score:.3f}")
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 稳定判断
|
||||||
|
# ---------------------------
|
||||||
|
stable_class_id = judge.update(class_id)
|
||||||
|
|
||||||
|
if stable_class_id is not None:
|
||||||
|
print(f"\n稳定输出: {CLASS_NAMES[stable_class_id]}\n")
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 显示画面(缩小窗口)
|
||||||
|
# ---------------------------
|
||||||
|
if show_window:
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
display_frame = cv2.resize(
|
||||||
|
frame,
|
||||||
|
(int(w * display_scale), int(h * display_scale)),
|
||||||
|
interpolation=cv2.INTER_AREA
|
||||||
|
)
|
||||||
|
|
||||||
|
cv2.imshow("RTSP Stream - Press 'q' to quit", display_frame)
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# ---------------------------
|
||||||
|
# 配置参数
|
||||||
|
# ---------------------------
|
||||||
|
MODEL_PATH = "zdb_cls.rknn"
|
||||||
|
ROI_FILE = "./roi_coordinates/zdb_roi.txt"
|
||||||
|
RTSP_URL = "rtsp://admin:XJ123456@192.168.250.60:554/streaming/channels/101"
|
||||||
|
|
||||||
|
STABLE_FRAMES = 3
|
||||||
|
DISPLAY_SCALE = 0.5 # 显示窗口缩放比例
|
||||||
|
SHOW_WINDOW = False # 部署时改成 False,测试的时候打开
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 打开 RTSP 视频流
|
||||||
|
# ---------------------------
|
||||||
|
print(f"正在连接 RTSP 流: {RTSP_URL}")
|
||||||
|
cap = cv2.VideoCapture(RTSP_URL)
|
||||||
|
|
||||||
|
# 降低 RTSP 延迟(部分摄像头支持)
|
||||||
|
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||||
|
|
||||||
|
if not cap.isOpened():
|
||||||
|
print("无法打开 RTSP 流,请检查网络、账号密码或 URL")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("RTSP 流连接成功,开始推理...")
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 启动稳定分类循环三帧稳定判断
|
||||||
|
# ---------------------------
|
||||||
|
run_stable_classification_loop(
|
||||||
|
model_path=MODEL_PATH,
|
||||||
|
roi_file=ROI_FILE,
|
||||||
|
image_source=cap,
|
||||||
|
stable_frames=STABLE_FRAMES,
|
||||||
|
display_scale=DISPLAY_SCALE,
|
||||||
|
show_window=SHOW_WINDOW
|
||||||
|
)
|
||||||
|
|
||||||
1
zdb_cls/roi_coordinates/zdb_roi.txt
Normal file
1
zdb_cls/roi_coordinates/zdb_roi.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
2,880,385,200
|
||||||
BIN
zdb_cls/test.png
Normal file
BIN
zdb_cls/test.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.9 MiB |
275
zdb_cls/test_imagesave.py
Normal file
275
zdb_cls/test_imagesave.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import deque
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# 稳定判决器
|
||||||
|
# =====================================================
|
||||||
|
class StableClassJudge:
|
||||||
|
"""
|
||||||
|
连续 N 帧稳定判决:
|
||||||
|
- class0 / class1 连续 N 帧 -> 输出
|
||||||
|
- class2 -> 清空计数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stable_frames=3, ignore_class=2):
|
||||||
|
self.stable_frames = stable_frames
|
||||||
|
self.ignore_class = ignore_class
|
||||||
|
self.buffer = deque(maxlen=stable_frames)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.buffer.clear()
|
||||||
|
|
||||||
|
def update(self, class_id):
|
||||||
|
if class_id == self.ignore_class:
|
||||||
|
self.reset()
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.buffer.append(class_id)
|
||||||
|
|
||||||
|
if len(self.buffer) < self.stable_frames:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(set(self.buffer)) == 1:
|
||||||
|
stable = self.buffer[0]
|
||||||
|
self.reset()
|
||||||
|
return stable
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# 类别定义
|
||||||
|
# =====================================================
|
||||||
|
CLASS_NAMES = {
|
||||||
|
0: "未插入振捣棒",
|
||||||
|
1: "安全插入振捣棒",
|
||||||
|
2: "有遮挡"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# RKNN 全局实例
|
||||||
|
# =====================================================
|
||||||
|
_global_rknn = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_rknn_model(model_path):
|
||||||
|
global _global_rknn
|
||||||
|
if _global_rknn is not None:
|
||||||
|
return _global_rknn
|
||||||
|
|
||||||
|
rknn = RKNNLite(verbose=False)
|
||||||
|
|
||||||
|
ret = rknn.load_rknn(model_path)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Load RKNN failed: {ret}")
|
||||||
|
|
||||||
|
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Init runtime failed: {ret}")
|
||||||
|
|
||||||
|
_global_rknn = rknn
|
||||||
|
print(f"[INFO] RKNN 模型加载成功: {model_path}")
|
||||||
|
return rknn
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# 图像预处理
|
||||||
|
# =====================================================
|
||||||
|
def letterbox(image, new_size=640, color=(114, 114, 114)):
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
scale = min(new_size / h, new_size / w)
|
||||||
|
nh, nw = int(h * scale), int(w * scale)
|
||||||
|
|
||||||
|
resized = cv2.resize(image, (nw, nh))
|
||||||
|
canvas = np.full((new_size, new_size, 3), color, dtype=np.uint8)
|
||||||
|
|
||||||
|
top = (new_size - nh) // 2
|
||||||
|
left = (new_size - nw) // 2
|
||||||
|
canvas[top:top + nh, left:left + nw] = resized
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def resize_stretch(image, size=640):
|
||||||
|
return cv2.resize(image, (size, size))
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image_for_rknn(
|
||||||
|
img,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
):
|
||||||
|
if resize_mode == "letterbox":
|
||||||
|
img = letterbox(img, size)
|
||||||
|
else:
|
||||||
|
img = resize_stretch(img, size)
|
||||||
|
|
||||||
|
if to_rgb:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
img /= 255.0
|
||||||
|
|
||||||
|
if layout == "NHWC":
|
||||||
|
img = np.expand_dims(img, axis=0)
|
||||||
|
else:
|
||||||
|
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
||||||
|
|
||||||
|
return np.ascontiguousarray(img)
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# RKNN 单次推理
|
||||||
|
# =====================================================
|
||||||
|
def rknn_classify_preprocessed(input_tensor, model_path):
|
||||||
|
rknn = init_rknn_model(model_path)
|
||||||
|
outs = rknn.inference([input_tensor])
|
||||||
|
probs = outs[0].reshape(-1).astype(float)
|
||||||
|
class_id = int(np.argmax(probs))
|
||||||
|
return class_id, probs
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# ROI 处理
|
||||||
|
# =====================================================
|
||||||
|
def load_single_roi(txt_path):
|
||||||
|
if not os.path.exists(txt_path):
|
||||||
|
raise RuntimeError(f"ROI 文件不存在: {txt_path}")
|
||||||
|
|
||||||
|
with open(txt_path) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
x, y, w, h = map(int, line.split(","))
|
||||||
|
return (x, y, w, h)
|
||||||
|
|
||||||
|
raise RuntimeError("ROI 文件为空")
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_return_roi(img, roi):
|
||||||
|
x, y, w, h = roi
|
||||||
|
H, W = img.shape[:2]
|
||||||
|
|
||||||
|
if x < 0 or y < 0 or x + w > W or y + h > H:
|
||||||
|
raise RuntimeError(f"ROI 超出图像范围: {roi}")
|
||||||
|
|
||||||
|
return img[y:y + h, x:x + w]
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# 单帧分类
|
||||||
|
# =====================================================
|
||||||
|
def classify_single_image(frame, model_path, roi_file):
|
||||||
|
roi = load_single_roi(roi_file)
|
||||||
|
roi_img = crop_and_return_roi(frame, roi)
|
||||||
|
|
||||||
|
input_tensor = preprocess_image_for_rknn(
|
||||||
|
roi_img,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
)
|
||||||
|
|
||||||
|
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"class_id": class_id,
|
||||||
|
"class": CLASS_NAMES[class_id],
|
||||||
|
"score": round(float(probs[class_id]), 4),
|
||||||
|
"raw": probs.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# RTSP 推理 + 保存分类结果
|
||||||
|
# =====================================================
|
||||||
|
def run_rtsp_classification_and_save(
|
||||||
|
model_path,
|
||||||
|
roi_file,
|
||||||
|
rtsp_url,
|
||||||
|
save_root="clsimg",
|
||||||
|
stable_frames=3,
|
||||||
|
save_mode="all" # all / stable
|
||||||
|
):
|
||||||
|
for cid in CLASS_NAMES.keys():
|
||||||
|
os.makedirs(os.path.join(save_root, f"class{cid}"), exist_ok=True)
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(rtsp_url)
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise RuntimeError(f"无法打开 RTSP: {rtsp_url}")
|
||||||
|
|
||||||
|
judge = StableClassJudge(stable_frames=stable_frames, ignore_class=2)
|
||||||
|
|
||||||
|
print("[INFO] RTSP 推理开始")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
print("[WARN] RTSP 读帧失败")
|
||||||
|
time.sleep(0.1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame = cv2.flip(frame, -1)
|
||||||
|
|
||||||
|
result = classify_single_image(frame, model_path, roi_file)
|
||||||
|
class_id = result["class_id"]
|
||||||
|
score = result["score"]
|
||||||
|
|
||||||
|
print(f"[FRAME] {result['class']} conf={score}")
|
||||||
|
|
||||||
|
stable = judge.update(class_id)
|
||||||
|
|
||||||
|
save_flag = False
|
||||||
|
save_class = class_id
|
||||||
|
|
||||||
|
if save_mode == "all":
|
||||||
|
save_flag = True
|
||||||
|
elif save_mode == "stable" and stable is not None:
|
||||||
|
save_flag = True
|
||||||
|
save_class = stable
|
||||||
|
|
||||||
|
if save_flag:
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
filename = f"{ts}_conf{score:.2f}.jpg"
|
||||||
|
save_dir = os.path.join(save_root, f"class{save_class}")
|
||||||
|
cv2.imwrite(os.path.join(save_dir, filename), frame)
|
||||||
|
print(f"[SAVE] class{save_class}/{filename}")
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================
|
||||||
|
# main
|
||||||
|
# =====================================================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_path = "zdb_cls.rknn"
|
||||||
|
roi_file = "./roi_coordinates/zdb_roi.txt"
|
||||||
|
|
||||||
|
rtsp_url = "rtsp://admin:XJ123456@192.168.250.60:554/streaming/channels/101"
|
||||||
|
|
||||||
|
run_rtsp_classification_and_save(
|
||||||
|
model_path=model_path,
|
||||||
|
roi_file=roi_file,
|
||||||
|
rtsp_url=rtsp_url,
|
||||||
|
save_root="clsimg",
|
||||||
|
stable_frames=3,
|
||||||
|
save_mode="all" # 改成 "stable" 只存稳定结果
|
||||||
|
)
|
||||||
|
|
||||||
BIN
zdb_cls/zdb_cls.rknn
Normal file
BIN
zdb_cls/zdb_cls.rknn
Normal file
Binary file not shown.
282
zdb_cls/zdb_cls_rknn.py
Normal file
282
zdb_cls/zdb_cls_rknn.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
class StableClassJudge:
|
||||||
|
"""
|
||||||
|
连续三帧稳定判决器:
|
||||||
|
- class0 / class1 连续 3 帧 -> 输出
|
||||||
|
- class2 -> 清空计数,重新统计
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stable_frames=3, ignore_class=2):
|
||||||
|
self.stable_frames = stable_frames
|
||||||
|
self.ignore_class = ignore_class
|
||||||
|
self.buffer = deque(maxlen=stable_frames)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.buffer.clear()
|
||||||
|
|
||||||
|
def update(self, class_id):
|
||||||
|
"""
|
||||||
|
输入单帧分类结果
|
||||||
|
返回:
|
||||||
|
- None:尚未稳定
|
||||||
|
- class_id:稳定输出结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 遇到 class2,直接清空重新计数
|
||||||
|
if class_id == self.ignore_class:
|
||||||
|
self.reset()
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.buffer.append(class_id)
|
||||||
|
|
||||||
|
# 缓冲未满
|
||||||
|
if len(self.buffer) < self.stable_frames:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 三帧完全一致
|
||||||
|
if len(set(self.buffer)) == 1:
|
||||||
|
stable_class = self.buffer[0]
|
||||||
|
self.reset() # 输出一次后重新计数(防止重复触发)
|
||||||
|
return stable_class
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 三分类映射,模具车1是小的,模具车2是大的
|
||||||
|
# ---------------------------
|
||||||
|
CLASS_NAMES = {
|
||||||
|
0: "模具车1",
|
||||||
|
1: "模具车2",
|
||||||
|
2: "有遮挡"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# RKNN 全局实例(只加载一次)
|
||||||
|
# ---------------------------
|
||||||
|
_global_rknn = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_rknn_model(model_path):
|
||||||
|
global _global_rknn
|
||||||
|
if _global_rknn is not None:
|
||||||
|
return _global_rknn
|
||||||
|
|
||||||
|
rknn = RKNNLite(verbose=False)
|
||||||
|
ret = rknn.load_rknn(model_path)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Load RKNN failed: {ret}")
|
||||||
|
|
||||||
|
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Init runtime failed: {ret}")
|
||||||
|
|
||||||
|
_global_rknn = rknn
|
||||||
|
print(f"[INFO] RKNN 模型加载成功: {model_path}")
|
||||||
|
return rknn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 预处理
|
||||||
|
# ---------------------------
|
||||||
|
def letterbox(image, new_size=640, color=(114, 114, 114)):
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
scale = min(new_size / h, new_size / w)
|
||||||
|
nh, nw = int(h * scale), int(w * scale)
|
||||||
|
resized = cv2.resize(image, (nw, nh))
|
||||||
|
new_img = np.full((new_size, new_size, 3), color, dtype=np.uint8)
|
||||||
|
top = (new_size - nh) // 2
|
||||||
|
left = (new_size - nw) // 2
|
||||||
|
new_img[top:top + nh, left:left + nw] = resized
|
||||||
|
return new_img
|
||||||
|
|
||||||
|
|
||||||
|
def resize_stretch(image, size=640):
|
||||||
|
return cv2.resize(image, (size, size))
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image_for_rknn(
|
||||||
|
img,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
):
|
||||||
|
if resize_mode == "letterbox":
|
||||||
|
img_box = letterbox(img, new_size=size)
|
||||||
|
else:
|
||||||
|
img_box = resize_stretch(img, size=size)
|
||||||
|
|
||||||
|
if to_rgb:
|
||||||
|
img_box = cv2.cvtColor(img_box, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
img_f = img_box.astype(np.float32)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
img_f /= 255.0
|
||||||
|
|
||||||
|
if layout == "NHWC":
|
||||||
|
out = np.expand_dims(img_f, axis=0)
|
||||||
|
else:
|
||||||
|
out = np.expand_dims(np.transpose(img_f, (2, 0, 1)), axis=0)
|
||||||
|
|
||||||
|
return np.ascontiguousarray(out)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 单次 RKNN 推理(三分类)
|
||||||
|
# ---------------------------
|
||||||
|
def rknn_classify_preprocessed(input_tensor, model_path):
|
||||||
|
rknn = init_rknn_model(model_path)
|
||||||
|
|
||||||
|
input_tensor = np.ascontiguousarray(input_tensor.astype(np.float32))
|
||||||
|
outs = rknn.inference([input_tensor])
|
||||||
|
|
||||||
|
pred = outs[0].reshape(-1).astype(float) # shape = (3,)
|
||||||
|
class_id = int(np.argmax(pred))
|
||||||
|
|
||||||
|
return class_id, pred
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# ROI
|
||||||
|
# ---------------------------
|
||||||
|
def load_single_roi(txt_path):
|
||||||
|
if not os.path.exists(txt_path):
|
||||||
|
raise RuntimeError(f"ROI 文件不存在: {txt_path}")
|
||||||
|
|
||||||
|
with open(txt_path) as f:
|
||||||
|
for line in f:
|
||||||
|
s = line.strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
x, y, w, h = map(int, s.split(','))
|
||||||
|
return (x, y, w, h)
|
||||||
|
|
||||||
|
raise RuntimeError("ROI 文件为空")
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_return_roi(img, roi):
|
||||||
|
x, y, w, h = roi
|
||||||
|
h_img, w_img = img.shape[:2]
|
||||||
|
|
||||||
|
if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
|
||||||
|
raise RuntimeError(f"ROI 超出图像范围: {roi}")
|
||||||
|
|
||||||
|
return img[y:y + h, x:x + w]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 单张图片推理(三分类)
|
||||||
|
# ---------------------------
|
||||||
|
def classify_single_image(
|
||||||
|
frame,
|
||||||
|
model_path,
|
||||||
|
roi_file,
|
||||||
|
size=640,
|
||||||
|
resize_mode="stretch",
|
||||||
|
to_rgb=True,
|
||||||
|
normalize=False,
|
||||||
|
layout="NHWC"
|
||||||
|
):
|
||||||
|
if frame is None:
|
||||||
|
raise FileNotFoundError("输入帧为空")
|
||||||
|
|
||||||
|
roi = load_single_roi(roi_file)
|
||||||
|
roi_img = crop_and_return_roi(frame, roi)
|
||||||
|
|
||||||
|
input_tensor = preprocess_image_for_rknn(
|
||||||
|
roi_img,
|
||||||
|
size=size,
|
||||||
|
resize_mode=resize_mode,
|
||||||
|
to_rgb=to_rgb,
|
||||||
|
normalize=normalize,
|
||||||
|
layout=layout
|
||||||
|
)
|
||||||
|
|
||||||
|
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
|
||||||
|
class_name = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"class_id": class_id,
|
||||||
|
"class": class_name,
|
||||||
|
"score": round(float(probs[class_id]), 4),
|
||||||
|
"raw": probs.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 示例调用
|
||||||
|
# ---------------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_path = "zdb_cls.rknn"
|
||||||
|
roi_file = "./roi_coordinates/zdb_roi.txt"
|
||||||
|
image_path = "./test_image/test.png"
|
||||||
|
|
||||||
|
frame = cv2.imread(image_path)
|
||||||
|
if frame is None:
|
||||||
|
raise FileNotFoundError(f"无法读取图片: {image_path}")
|
||||||
|
|
||||||
|
result = classify_single_image(frame, model_path, roi_file)
|
||||||
|
print("[RESULT]", result)
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 示例判断逻辑
|
||||||
|
'''
|
||||||
|
import cv2
|
||||||
|
from muju_cls_rknn import classify_single_image,StableClassJudge,CLASS_NAMES
|
||||||
|
|
||||||
|
def run_stable_classification_loop(
|
||||||
|
model_path,
|
||||||
|
roi_file,
|
||||||
|
image_source,
|
||||||
|
stable_frames=3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
image_source:
|
||||||
|
- cv2.VideoCapture
|
||||||
|
"""
|
||||||
|
judge = StableClassJudge(
|
||||||
|
stable_frames=stable_frames,
|
||||||
|
ignore_class=2 # 有遮挡
|
||||||
|
)
|
||||||
|
|
||||||
|
cap = image_source
|
||||||
|
if not hasattr(cap, "read"):
|
||||||
|
raise TypeError("image_source 必须是 cv2.VideoCapture")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
# 上下左右翻转
|
||||||
|
frame = cv2.flip(frame, -1)
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
print("读取帧失败,退出")
|
||||||
|
break
|
||||||
|
|
||||||
|
result = classify_single_image(frame, model_path, roi_file)
|
||||||
|
|
||||||
|
class_id = result["class_id"]
|
||||||
|
class_name = result["class"]
|
||||||
|
score = result["score"]
|
||||||
|
|
||||||
|
print(f"[FRAME] {class_name} conf={score}")
|
||||||
|
|
||||||
|
stable = judge.update(class_id)
|
||||||
|
|
||||||
|
if stable is not None:
|
||||||
|
print(f"\n稳定输出: {CLASS_NAMES[stable]} \n")
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
'''
|
||||||
|
# ---------------------------
|
||||||
Reference in New Issue
Block a user