Files
wood_vis/wood_detect/wood_detect.py

213 lines
6.6 KiB
Python
Raw Permalink Normal View History

2026-02-25 14:24:05 +08:00
# -*- coding: utf-8 -*-
"""
木条检测模块基于RKNNLite + ROI + 单类别NMS
功能
- 检测ROI区域内木条数量
- 输出实际木条数量
"""
import os
from typing import Optional
import cv2
import numpy as np
from rknnlite.api import RKNNLite
# =====================================================
# Configuration constants
# =====================================================
# Compiled RKNN model file. This is a relative path, so it is interpreted
# relative to the process working directory when the model is loaded.
RKNN_MODEL_PATH: str = "wood_detect.rknn"

# Region of interest in source-image pixels, ordered (x1, y1, x2, y2).
# The detector clamps it to the image and slices img[y1:y2, x1:x2].
ROI: tuple[int, int, int, int] = (1, 1, 1000, 1000)

# Model input resolution as (width, height); the ROI crop is letterboxed to it.
IMG_SIZE: tuple[int, int] = (640, 640)

# Minimum combined score for a candidate box to be kept before NMS.
OBJ_THRESH: float = 0.25

# IoU above which a lower-scoring box is suppressed during NMS.
NMS_THRESH: float = 0.45

# The single class the model detects. TODO: rename to "wood" (per original note).
CLASS_NAME: list[str] = ["bag"]
# =====================================================
# Private helper functions
# =====================================================
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``x`` taken along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    logits cannot overflow ``np.exp``; the result is mathematically unchanged.
    """
    shifted = x - np.amax(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer
def _letterbox_resize(image: np.ndarray,
                      size: tuple[int, int],
                      bg_color: int = 114) -> tuple[np.ndarray, float, int, int]:
    """Resize ``image`` keeping its aspect ratio, then pad it to ``size``.

    Args:
        image: Source image of shape ``(H, W)`` or ``(H, W, C)``.
        size: Target ``(width, height)``.
        bg_color: Fill value for the padding border.

    Returns:
        ``(canvas, scale, dx, dy)``:
        - ``canvas``: the padded image, shape ``(height, width[, C])``, with
          the same dtype as ``image``.
        - ``scale``: the uniform scale factor that was applied.
        - ``dx``, ``dy``: the left and top padding in pixels.

        A canvas point ``(x, y)`` maps back to the source image as
        ``((x - dx) / scale, (y - dy) / scale)``.

    Raises:
        ValueError: if ``image`` has zero width or height.
    """
    target_w, target_h = size
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot letterbox an empty image of shape {image.shape}")

    scale = min(target_w / w, target_h / h)
    # int() truncates, so for extreme aspect ratios a side could become 0,
    # which cv2.resize rejects; keep at least one pixel per side.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h))
    # cv2.resize drops a trailing singleton channel axis; restore the input's
    # channel layout so the slice assignment below lines up.
    channel_shape = image.shape[2:]
    resized = resized.reshape((new_h, new_w) + channel_shape)

    canvas = np.full((target_h, target_w) + channel_shape, bg_color, dtype=image.dtype)
    # Center the resized image; any odd leftover pixel goes to the right/bottom.
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized
    return canvas, scale, dx, dy
def _nms(boxes: np.ndarray, scores: np.ndarray, thresh: float) -> list[int]:
    """Greedy, class-agnostic non-maximum suppression.

    Args:
        boxes: ``(N, 4)`` array of ``(x1, y1, x2, y2)`` corners.
        scores: ``(N,)`` confidence per box.
        thresh: A box is dropped when its IoU with an already-kept,
            higher-scoring box exceeds this value.

    Returns:
        Indices into ``boxes`` of the kept boxes, highest score first.
    """
    left, top, right, bottom = boxes.T
    box_area = (right - left) * (bottom - top)
    remaining = scores.argsort()[::-1]  # descending by score
    selected: list[int] = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)
        # Intersection rectangle between the best box and every other candidate.
        inter_w = np.maximum(0, np.minimum(right[best], right[rest])
                             - np.maximum(left[best], left[rest]))
        inter_h = np.maximum(0, np.minimum(bottom[best], bottom[rest])
                             - np.maximum(top[best], top[rest]))
        overlap = inter_w * inter_h
        iou = overlap / (box_area[best] + box_area[rest] - overlap)
        remaining = rest[iou <= thresh]
    return selected
def _decode_branch(reg: np.ndarray,
                   cls: np.ndarray,
                   obj: np.ndarray,
                   stride: int,
                   obj_thresh: float) -> Optional[tuple[np.ndarray, np.ndarray]]:
    """Decode one output branch into boxes and scores in letterbox pixels.

    Args:
        reg: Box-distribution logits; reshaped below to ``(4, reg_max, H*W)``.
        cls: Class scores of shape ``(num_classes, H, W)``.
        obj: Per-cell multiplier with ``H*W`` elements (see the NOTE in
            ``_post_process``).
        stride: Downsampling factor of this branch.
        obj_thresh: Minimum combined score for a cell to be kept.

    Returns:
        ``(boxes, scores)``, where ``boxes`` is an ``(N, 4)`` xyxy array in
        letterboxed-input pixels. Returns None if no cell passes ``obj_thresh``.
    """
    num_classes, _, grid_w = cls.shape
    reg_max = reg.shape[0] // 4

    # Best class score per cell multiplied by the branch's third tensor.
    scores = cls.reshape(num_classes, -1).max(axis=0) * obj.ravel()
    valid_idx = np.flatnonzero(scores >= obj_thresh)
    if valid_idx.size == 0:
        return None

    # Flat cell index is row-major over (H, W): idx = row * W + col.
    gy, gx = np.divmod(valid_idx, grid_w)
    gx = gx.astype(np.float32)
    gy = gy.astype(np.float32)

    # DFL decode: softmax over reg_max bins per side; the expected bin index is
    # the distance from the cell center, in units of the stride.
    logits = reg.reshape(4, reg_max, -1)[:, :, valid_idx]
    bins = np.arange(reg_max, dtype=np.float32).reshape(1, -1, 1)
    left, top, right, bottom = np.sum(_softmax(logits, axis=1) * bins, axis=1)

    cx = (gx + 0.5) * stride
    cy = (gy + 0.5) * stride
    boxes = np.stack([cx - left * stride,
                      cy - top * stride,
                      cx + right * stride,
                      cy + bottom * stride], axis=1)
    return boxes, scores[valid_idx]


def _post_process(outputs: list[np.ndarray],
                  scale: float,
                  dx: int,
                  dy: int,
                  obj_thresh: Optional[float] = None,
                  nms_thresh: Optional[float] = None) -> Optional[np.ndarray]:
    """Decode raw RKNN outputs, undo the letterbox, and apply NMS.

    ``outputs`` holds three tensors per stride (8, 16, 32), in the order
    ``outputs[3*i + 0]`` box regression, ``outputs[3*i + 1]`` class scores,
    ``outputs[3*i + 2]`` per-cell multiplier. Each has a leading batch axis,
    of which only element 0 is used.

    NOTE(review): the third tensor is multiplied into the class score as if
    it were objectness. If the model is a YOLOv8 export from rknn_model_zoo,
    that tensor is a clipped sum of class scores ("score_sum"), not
    objectness. With one class this squares the confidence, so OBJ_THRESH=0.25
    would act like ~0.5 on the raw class score. Confirm against the export
    script.

    Args:
        outputs: Raw tensors returned by ``RKNNLite.inference``.
        scale, dx, dy: Letterbox parameters from ``_letterbox_resize``.
        obj_thresh: Score threshold; defaults to the module's ``OBJ_THRESH``.
        nms_thresh: NMS IoU threshold; defaults to the module's ``NMS_THRESH``.

    Returns:
        ``(K, 4)`` xyxy boxes in pixel coordinates of the image that was
        letterboxed, or None when no candidate passes the score threshold.

    Raises:
        ValueError: if ``outputs`` has fewer than three tensors per stride.
    """
    # Resolve defaults at call time so runtime edits of the constants apply.
    obj_thresh = OBJ_THRESH if obj_thresh is None else obj_thresh
    nms_thresh = NMS_THRESH if nms_thresh is None else nms_thresh
    strides = (8, 16, 32)
    if outputs is None or len(outputs) < 3 * len(strides):
        got = None if outputs is None else len(outputs)
        raise ValueError(f"expected {3 * len(strides)} model outputs, got {got}")

    boxes_list, scores_list = [], []
    for i, stride in enumerate(strides):
        decoded = _decode_branch(outputs[3 * i][0],
                                 outputs[3 * i + 1][0],
                                 outputs[3 * i + 2][0],
                                 stride,
                                 obj_thresh)
        if decoded is None:
            continue
        boxes, scores = decoded
        boxes_list.append(boxes)
        scores_list.append(scores)

    if not boxes_list:
        return None

    boxes_all = np.concatenate(boxes_list, axis=0)
    # Undo the letterbox: remove the padding offset, then rescale.
    boxes_all[:, [0, 2]] = (boxes_all[:, [0, 2]] - dx) / scale
    boxes_all[:, [1, 3]] = (boxes_all[:, [1, 3]] - dy) / scale
    scores_all = np.concatenate(scores_list, axis=0)

    # Single-class model, so one class-agnostic NMS pass is sufficient.
    keep_idx = _nms(boxes_all, scores_all, nms_thresh)
    return boxes_all[keep_idx]
# =====================================================
# Global RKNN model initialization
# =====================================================
def _init_rknn(model_path: str) -> RKNNLite:
    """Create an RKNNLite runtime, load ``model_path`` into it and initialize it.

    A relative ``model_path`` is first looked up relative to the current
    working directory, as before. If it is not found there, the directory
    containing this module is tried as a fallback.

    Raises:
        FileNotFoundError: if the model file cannot be found.
        RuntimeError: if ``load_rknn`` or ``init_runtime`` returns a non-zero
            status code.
    """
    path = model_path
    if not os.path.isabs(path) and not os.path.exists(path):
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, path)
        if os.path.exists(candidate):
            path = candidate
    if not os.path.exists(path):
        raise FileNotFoundError(f"RKNN model not found: {model_path}")

    rknn = RKNNLite()
    # Per the RKNN-Toolkit-Lite2 API these calls report failure through a
    # non-zero return code rather than an exception. Ignoring the code would
    # only surface later as an opaque crash on inference()'s result.
    ret = rknn.load_rknn(path)
    if ret != 0:
        rknn.release()
        raise RuntimeError(f"load_rknn failed for {path} (ret={ret})")
    ret = rknn.init_runtime()
    if ret != 0:
        rknn.release()
        raise RuntimeError(f"init_runtime failed for {path} (ret={ret})")
    return rknn


# Shared runtime, created once at import time and reused by every detector.
_global_rknn = _init_rknn(RKNN_MODEL_PATH)
# =====================================================
# Wood-strip detector class
# =====================================================
class WoodDetectorRKNN:
    """Count wood strips inside the configured ROI using an RKNN model."""

    def __init__(self, rknn: Optional["RKNNLite"] = None):
        """Create a detector.

        Args:
            rknn: An already-initialized RKNNLite runtime. Defaults to the
                module-level shared runtime ``_global_rknn``.
        """
        self.rknn = _global_rknn if rknn is None else rknn

    def detect(self, img_np: np.ndarray) -> int:
        """Count the wood strips in ``img_np``.

        Args:
            img_np: Source image, BGR ``(H, W, 3)`` or grayscale ``(H, W)``.
                Grayscale input is expanded to 3 channels.

        Returns:
            The number of boxes that survive NMS inside the ROI. Returns 0
            when the ROI does not overlap the image.

        Raises:
            ValueError: if ``img_np`` is None or empty.
            RuntimeError: if the RKNN runtime returns no outputs.
        """
        if img_np is None or img_np.size == 0:
            raise ValueError("img_np must be a non-empty image array")

        # Clamp the ROI to the image bounds before cropping.
        h, w = img_np.shape[:2]
        x1, y1, x2, y2 = ROI
        x1 = max(0, min(x1, w - 1))
        x2 = max(0, min(x2, w))
        y1 = max(0, min(y1, h - 1))
        y2 = max(0, min(y2, h))
        if x2 <= x1 or y2 <= y1:
            # Inverted or fully out-of-image ROI: there is nothing to detect.
            return 0
        roi_img = img_np[y1:y2, x1:x2]
        if roi_img.ndim == 2:
            roi_img = cv2.cvtColor(roi_img, cv2.COLOR_GRAY2BGR)

        resized_img, scale, dx, dy = _letterbox_resize(roi_img, IMG_SIZE)

        # NOTE(review): the model is fed a float32 BGR NHWC tensor. YOLO-style
        # RKNN exports usually expect RGB, and quantized models usually take
        # uint8; confirm against the model's export configuration.
        input_tensor = np.expand_dims(resized_img.astype(np.float32), axis=0)
        outputs = self.rknn.inference([input_tensor])
        if outputs is None:
            raise RuntimeError("RKNN inference failed: no outputs returned")

        boxes = _post_process(outputs, scale, dx, dy)
        return 0 if boxes is None else len(boxes)
# =====================================================
# Public entry point
# =====================================================
# Module-wide detector instance shared by every detect_wood() call.
_detector = WoodDetectorRKNN()


def detect_wood(img_np: np.ndarray) -> int:
    """Count wood strips in ``img_np`` using the shared module detector.

    Args:
        img_np: Source BGR image.

    Returns:
        The number of wood strips detected inside the configured ROI.
    """
    detector = _detector
    count = detector.detect(img_np)
    return count
# =====================================================
# Script entry point
# =====================================================
if __name__ == "__main__":
    import sys

    # The image path can be passed as the first CLI argument; it defaults to
    # "1.jpg" so the previous no-argument invocation still works.
    img_path = sys.argv[1] if len(sys.argv) > 1 else "1.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"图片不存在: {img_path}")
    # cv2.imread returns None (rather than raising) for unreadable files.
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("图像加载失败")
    total_count = detect_wood(img)
    print(f"检测到木条数量: {total_count}")