更新液面diff代码
This commit is contained in:
275
muju_cls/test_imagesave.py
Normal file
275
muju_cls/test_imagesave.py
Normal file
@ -0,0 +1,275 @@
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
# =====================================================
|
||||
# 稳定判决器
|
||||
# =====================================================
|
||||
class StableClassJudge:
|
||||
"""
|
||||
连续 N 帧稳定判决:
|
||||
- class0 / class1 连续 N 帧 -> 输出
|
||||
- class2 -> 清空计数
|
||||
"""
|
||||
|
||||
def __init__(self, stable_frames=3, ignore_class=2):
|
||||
self.stable_frames = stable_frames
|
||||
self.ignore_class = ignore_class
|
||||
self.buffer = deque(maxlen=stable_frames)
|
||||
|
||||
def reset(self):
|
||||
self.buffer.clear()
|
||||
|
||||
def update(self, class_id):
|
||||
if class_id == self.ignore_class:
|
||||
self.reset()
|
||||
return None
|
||||
|
||||
self.buffer.append(class_id)
|
||||
|
||||
if len(self.buffer) < self.stable_frames:
|
||||
return None
|
||||
|
||||
if len(set(self.buffer)) == 1:
|
||||
stable = self.buffer[0]
|
||||
self.reset()
|
||||
return stable
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =====================================================
|
||||
# 类别定义
|
||||
# =====================================================
|
||||
CLASS_NAMES = {
|
||||
0: "模具车1",
|
||||
1: "模具车2",
|
||||
2: "有遮挡"
|
||||
}
|
||||
|
||||
|
||||
# =====================================================
|
||||
# RKNN 全局实例
|
||||
# =====================================================
|
||||
_global_rknn = None
|
||||
|
||||
|
||||
def init_rknn_model(model_path):
|
||||
global _global_rknn
|
||||
if _global_rknn is not None:
|
||||
return _global_rknn
|
||||
|
||||
rknn = RKNNLite(verbose=False)
|
||||
|
||||
ret = rknn.load_rknn(model_path)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Load RKNN failed: {ret}")
|
||||
|
||||
ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Init runtime failed: {ret}")
|
||||
|
||||
_global_rknn = rknn
|
||||
print(f"[INFO] RKNN 模型加载成功: {model_path}")
|
||||
return rknn
|
||||
|
||||
|
||||
# =====================================================
|
||||
# 图像预处理
|
||||
# =====================================================
|
||||
def letterbox(image, new_size=640, color=(114, 114, 114)):
|
||||
h, w = image.shape[:2]
|
||||
scale = min(new_size / h, new_size / w)
|
||||
nh, nw = int(h * scale), int(w * scale)
|
||||
|
||||
resized = cv2.resize(image, (nw, nh))
|
||||
canvas = np.full((new_size, new_size, 3), color, dtype=np.uint8)
|
||||
|
||||
top = (new_size - nh) // 2
|
||||
left = (new_size - nw) // 2
|
||||
canvas[top:top + nh, left:left + nw] = resized
|
||||
return canvas
|
||||
|
||||
|
||||
def resize_stretch(image, size=640):
|
||||
return cv2.resize(image, (size, size))
|
||||
|
||||
|
||||
def preprocess_image_for_rknn(
|
||||
img,
|
||||
size=640,
|
||||
resize_mode="stretch",
|
||||
to_rgb=True,
|
||||
normalize=False,
|
||||
layout="NHWC"
|
||||
):
|
||||
if resize_mode == "letterbox":
|
||||
img = letterbox(img, size)
|
||||
else:
|
||||
img = resize_stretch(img, size)
|
||||
|
||||
if to_rgb:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if normalize:
|
||||
img /= 255.0
|
||||
|
||||
if layout == "NHWC":
|
||||
img = np.expand_dims(img, axis=0)
|
||||
else:
|
||||
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
||||
|
||||
return np.ascontiguousarray(img)
|
||||
|
||||
|
||||
# =====================================================
|
||||
# RKNN 单次推理
|
||||
# =====================================================
|
||||
def rknn_classify_preprocessed(input_tensor, model_path):
|
||||
rknn = init_rknn_model(model_path)
|
||||
outs = rknn.inference([input_tensor])
|
||||
probs = outs[0].reshape(-1).astype(float)
|
||||
class_id = int(np.argmax(probs))
|
||||
return class_id, probs
|
||||
|
||||
|
||||
# =====================================================
|
||||
# ROI 处理
|
||||
# =====================================================
|
||||
def load_single_roi(txt_path):
|
||||
if not os.path.exists(txt_path):
|
||||
raise RuntimeError(f"ROI 文件不存在: {txt_path}")
|
||||
|
||||
with open(txt_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
x, y, w, h = map(int, line.split(","))
|
||||
return (x, y, w, h)
|
||||
|
||||
raise RuntimeError("ROI 文件为空")
|
||||
|
||||
|
||||
def crop_and_return_roi(img, roi):
|
||||
x, y, w, h = roi
|
||||
H, W = img.shape[:2]
|
||||
|
||||
if x < 0 or y < 0 or x + w > W or y + h > H:
|
||||
raise RuntimeError(f"ROI 超出图像范围: {roi}")
|
||||
|
||||
return img[y:y + h, x:x + w]
|
||||
|
||||
|
||||
# =====================================================
|
||||
# 单帧分类
|
||||
# =====================================================
|
||||
def classify_single_image(frame, model_path, roi_file):
|
||||
roi = load_single_roi(roi_file)
|
||||
roi_img = crop_and_return_roi(frame, roi)
|
||||
|
||||
input_tensor = preprocess_image_for_rknn(
|
||||
roi_img,
|
||||
size=640,
|
||||
resize_mode="stretch",
|
||||
to_rgb=True,
|
||||
normalize=False,
|
||||
layout="NHWC"
|
||||
)
|
||||
|
||||
class_id, probs = rknn_classify_preprocessed(input_tensor, model_path)
|
||||
|
||||
return {
|
||||
"class_id": class_id,
|
||||
"class": CLASS_NAMES[class_id],
|
||||
"score": round(float(probs[class_id]), 4),
|
||||
"raw": probs.tolist()
|
||||
}
|
||||
|
||||
|
||||
# =====================================================
|
||||
# RTSP 推理 + 保存分类结果
|
||||
# =====================================================
|
||||
def run_rtsp_classification_and_save(
|
||||
model_path,
|
||||
roi_file,
|
||||
rtsp_url,
|
||||
save_root="clsimg",
|
||||
stable_frames=3,
|
||||
save_mode="all" # all / stable
|
||||
):
|
||||
for cid in CLASS_NAMES.keys():
|
||||
os.makedirs(os.path.join(save_root, f"class{cid}"), exist_ok=True)
|
||||
|
||||
cap = cv2.VideoCapture(rtsp_url)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"无法打开 RTSP: {rtsp_url}")
|
||||
|
||||
judge = StableClassJudge(stable_frames=stable_frames, ignore_class=2)
|
||||
|
||||
print("[INFO] RTSP 推理开始")
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("[WARN] RTSP 读帧失败")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
frame = cv2.flip(frame, -1)
|
||||
|
||||
result = classify_single_image(frame, model_path, roi_file)
|
||||
class_id = result["class_id"]
|
||||
score = result["score"]
|
||||
|
||||
print(f"[FRAME] {result['class']} conf={score}")
|
||||
|
||||
stable = judge.update(class_id)
|
||||
|
||||
save_flag = False
|
||||
save_class = class_id
|
||||
|
||||
if save_mode == "all":
|
||||
save_flag = True
|
||||
elif save_mode == "stable" and stable is not None:
|
||||
save_flag = True
|
||||
save_class = stable
|
||||
|
||||
if save_flag:
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{ts}_conf{score:.2f}.jpg"
|
||||
save_dir = os.path.join(save_root, f"class{save_class}")
|
||||
cv2.imwrite(os.path.join(save_dir, filename), frame)
|
||||
print(f"[SAVE] class{save_class}/{filename}")
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
# =====================================================
|
||||
# main
|
||||
# =====================================================
|
||||
if __name__ == "__main__":
|
||||
model_path = "muju_cls.rknn"
|
||||
roi_file = "./roi_coordinates/muju_roi.txt"
|
||||
|
||||
rtsp_url = "rtsp://admin:XJ123456@192.168.250.61:554/streaming/channels/101"
|
||||
|
||||
run_rtsp_classification_and_save(
|
||||
model_path=model_path,
|
||||
roi_file=roi_file,
|
||||
rtsp_url=rtsp_url,
|
||||
save_root="clsimg",
|
||||
stable_frames=3,
|
||||
save_mode="all" # 改成 "stable" 只存稳定结果
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user