94 lines
3.6 KiB
Python
94 lines
3.6 KiB
Python
|
|
# yolo11_main.py
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from collections import deque
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# 导入模块(不是函数)
|
|||
|
|
from .aligment_inference import yolov11_cls_inference
|
|||
|
|
|
|||
|
|
# 模型路径
|
|||
|
|
CLS_MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "yolov11_cls_640v6.rknn")
|
|||
|
|
|
|||
|
|
class ClassificationStabilizer:
|
|||
|
|
"""分类结果稳定性校验器,处理瞬时噪声帧"""
|
|||
|
|
|
|||
|
|
def __init__(self, window_size=5, switch_threshold=2):
|
|||
|
|
self.window_size = window_size # 滑动窗口大小(缓存最近N帧结果)
|
|||
|
|
self.switch_threshold = switch_threshold # 状态切换需要连续N帧一致
|
|||
|
|
self.result_buffer = deque(maxlen=window_size) # 缓存最近结果
|
|||
|
|
self.current_state = "盖板未对齐" # 初始状态
|
|||
|
|
self.consecutive_count = 0 # 当前状态连续出现的次数
|
|||
|
|
|
|||
|
|
def stabilize(self, current_frame_result):
|
|||
|
|
"""
|
|||
|
|
输入当前帧的分类结果,返回经过稳定性校验的结果
|
|||
|
|
Args:
|
|||
|
|
current_frame_result: 当前帧的原始分类结果(str)
|
|||
|
|
Returns:
|
|||
|
|
str: 经过校验的稳定结果
|
|||
|
|
"""
|
|||
|
|
# 1. 将当前帧结果加入滑动窗口
|
|||
|
|
self.result_buffer.append(current_frame_result)
|
|||
|
|
|
|||
|
|
# 2. 统计窗口内各结果的出现次数(多数投票基础)
|
|||
|
|
result_counts = {}
|
|||
|
|
for res in self.result_buffer:
|
|||
|
|
result_counts[res] = result_counts.get(res, 0) + 1 # 使用 result_counts 字典记录每个元素出现的总次数。
|
|||
|
|
|
|||
|
|
# 3. 找到窗口中出现次数最多的结果(候选结果)
|
|||
|
|
candidate = max(result_counts, key=result_counts.get)
|
|||
|
|
|
|||
|
|
# 4. 状态切换校验:只有候选结果连续出现N次才允许切换
|
|||
|
|
if candidate == self.current_state:
|
|||
|
|
# 与当前状态一致,重置连续计数
|
|||
|
|
self.consecutive_count = 0
|
|||
|
|
else:
|
|||
|
|
# 与当前状态不一致,累计连续次数
|
|||
|
|
self.consecutive_count += 1
|
|||
|
|
# 连续达到阈值,才更新状态
|
|||
|
|
if self.consecutive_count >= self.switch_threshold:
|
|||
|
|
self.current_state = candidate
|
|||
|
|
self.consecutive_count = 0
|
|||
|
|
|
|||
|
|
return self.current_state
|
|||
|
|
|
|||
|
|
# 初始化稳定性校验器(全局唯一实例,确保状态连续)
|
|||
|
|
cls_stabilizer = ClassificationStabilizer(
|
|||
|
|
window_size=5, # 缓存最近5帧
|
|||
|
|
switch_threshold=2 # 连续2帧一致才切换状态
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# ====================== 分类接口(可选,保持原逻辑) ======================
|
|||
|
|
def run_yolo_classification(rgb_frame):
|
|||
|
|
"""
|
|||
|
|
YOLO 图像分类接口函数
|
|||
|
|
Args:
|
|||
|
|
rgb_frame: numpy array (H, W, 3), RGB 格式
|
|||
|
|
Returns:
|
|||
|
|
str: 分类结果("盖板对齐" / "盖板未对齐" / "异常")
|
|||
|
|
"""
|
|||
|
|
if not isinstance(rgb_frame, np.ndarray):
|
|||
|
|
print(f"[ERROR] 输入类型错误:需为 np.ndarray,当前为 {type(rgb_frame)}")
|
|||
|
|
return "异常"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
cover_cls = yolov11_cls_inference(CLS_MODEL_PATH, rgb_frame, target_size=(640, 640))
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"[WARN] 分类推理失败: {e}")
|
|||
|
|
cover_cls = "异常"
|
|||
|
|
|
|||
|
|
raw_result = "盖板未对齐" # 默认值
|
|||
|
|
# 结果映射
|
|||
|
|
if cover_cls == "cover_ready":
|
|||
|
|
raw_result = "盖板对齐"
|
|||
|
|
elif cover_cls == "cover_noready":
|
|||
|
|
raw_result = "盖板未对齐"
|
|||
|
|
else:
|
|||
|
|
raw_result = "异常"
|
|||
|
|
# 通过稳定性校验器处理,返回最终结果
|
|||
|
|
stable_result = cls_stabilizer.stabilize(raw_result)
|
|||
|
|
print("raw_result, stable_result:",raw_result, stable_result)
|
|||
|
|
return stable_result
|
|||
|
|
|