Files
wood_vis/wood_detect/wood_detect.py
2026-02-27 15:08:35 +08:00

213 lines
6.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
木条检测模块基于RKNNLite + ROI + 单类别NMS
功能:
- 检测ROI区域内木条数量
- 输出实际木条数量
"""
import os
from typing import Optional
import cv2
import numpy as np
from rknnlite.api import RKNNLite
# =====================================================
# 常量配置
# =====================================================
# Path of the model file handed to RKNNLite.load_rknn() at import time.
RKNN_MODEL_PATH: str = "wood_detect.rknn"
# WoodDetectorRKNN.detect clamps this rectangle to the image bounds before cropping.
ROI: tuple[int, int, int, int] = (1, 1, 1000, 1000) # ROI coordinates (x1, y1, x2, y2)
IMG_SIZE: tuple[int, int] = (640, 640) # Model input size; unpacked as (width, height) by _letterbox_resize
OBJ_THRESH: float = 0.25 # Confidence threshold applied to candidate scores
NMS_THRESH: float = 0.45 # IoU threshold used by NMS
CLASS_NAME: list[str] = ["bag"] # Single class name; to be changed to "wood" later
# =====================================================
# 私有工具函数
# =====================================================
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return softmax probabilities of ``x`` computed along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so the
    exponentials stay bounded for large inputs.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer
def _letterbox_resize(image: np.ndarray,
                      size: tuple[int, int],
                      bg_color: int = 114) -> tuple[np.ndarray, float, int, int]:
    """Resize ``image`` to fit inside ``size`` keeping aspect ratio, then pad.

    Args:
        image: Source image; only its first two dimensions (height, width)
            are inspected here. The padded canvas is always 3-channel uint8.
        size: Target size as ``(width, height)``.
        bg_color: Fill value used for the padding area.

    Returns:
        ``(canvas, scale, dx, dy)`` where ``canvas`` has shape
        ``(target_h, target_w, 3)``, ``scale`` is the resize factor applied to
        the source, and ``dx`` / ``dy`` are the left / top padding offsets of
        the resized image inside the canvas.

    Raises:
        ValueError: If ``image`` has zero height or width.
    """
    target_w, target_h = size
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"cannot letterbox an empty image (shape={image.shape})")

    scale = min(target_w / w, target_h / h)
    # int() truncation can yield 0 for very elongated inputs, which
    # cv2.resize rejects; keep each side within [1, target].
    new_w = min(target_w, max(1, int(w * scale)))
    new_h = min(target_h, max(1, int(h * scale)))
    resized = cv2.resize(image, (new_w, new_h))

    canvas = np.full((target_h, target_w, 3), bg_color, dtype=np.uint8)
    # Centre the resized image; the remainder stays filled with bg_color.
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized
    return canvas, scale, dx, dy
def _nms(boxes: np.ndarray, scores: np.ndarray, thresh: float) -> list[int]:
    """Greedy non-maximum suppression over ``(x1, y1, x2, y2)`` boxes.

    Args:
        boxes: Array of shape ``(N, 4)`` holding ``x1, y1, x2, y2`` per row.
        scores: Array of ``N`` scores; higher scores are visited first.
        thresh: A box is dropped when its IoU with an already kept box is
            greater than this value.

    Returns:
        Indices into ``boxes`` of the kept rows, as plain Python ints, in
        descending score order.
    """
    if len(boxes) == 0:
        return []

    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]
    keep: list[int] = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        # Store a plain int so the result matches the list[int] annotation.
        keep.append(int(best))

        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        union = areas[best] + areas[rest] - inter

        # Zero-area boxes give union == 0; a raw division would yield NaN,
        # and ``NaN <= thresh`` is False, silently discarding the box.
        # Treat such pairs as non-overlapping (IoU 0) instead.
        has_area = union > 0
        safe_union = np.where(has_area, union, 1.0)
        iou = np.where(has_area, inter / safe_union, 0.0)

        order = rest[iou <= thresh]
    return keep
def _post_process(outputs: list[np.ndarray],
                  scale: float,
                  dx: int,
                  dy: int) -> Optional[np.ndarray]:
    """Decode raw model outputs into NMS-filtered boxes in source coordinates.

    ``outputs`` is read as three consecutive tensors per stride (8, 16, 32):
    ``outputs[3*i]`` (``reg``), ``outputs[3*i + 1]`` (``cls``) and
    ``outputs[3*i + 2]`` (``obj``). The leading axis of each is dropped with
    ``[0]``; ``cls`` then unpacks as ``(num_classes, H, W)`` and the first
    axis of ``reg`` is split into 4 sides x ``reg_max`` bins.

    Args:
        outputs: Tensors returned by the model, laid out as described above.
        scale: Letterbox resize factor; boxes are divided by it.
        dx: Horizontal letterbox padding, subtracted from x coordinates.
        dy: Vertical letterbox padding, subtracted from y coordinates.

    Returns:
        Array of kept boxes, one ``(x1, y1, x2, y2)`` row each, or ``None``
        when no cell reaches ``OBJ_THRESH`` on any stride.
    """
    kept_boxes: list[np.ndarray] = []
    kept_scores: list[np.ndarray] = []

    for level, stride in enumerate((8, 16, 32)):
        base = level * 3
        reg = outputs[base][0]
        cls = outputs[base + 1][0]
        obj = outputs[base + 2][0]

        num_classes, height, width = cls.shape
        reg_max = reg.shape[0] // 4

        # Per-cell score: best class score multiplied by the ``obj`` value.
        best_cls = np.max(cls.reshape(num_classes, -1).T, axis=1)
        scores = best_cls * obj.ravel()

        selected = np.flatnonzero(scores >= OBJ_THRESH)
        if selected.size == 0:
            continue

        # Recover (row, col) grid positions of the selected flat indices.
        rows, cols = np.unravel_index(selected, (height, width))
        gx = cols.astype(np.float32)
        gy = rows.astype(np.float32)

        # Expected value over the reg_max bins gives one distance per side.
        logits = reg.reshape(4, reg_max, -1)[:, :, selected]
        probs = _softmax(logits, axis=1)
        bins = np.arange(reg_max, dtype=np.float32).reshape(1, -1, 1)
        left, top, right, bottom = np.sum(probs * bins, axis=1)

        # Cell centre in model-input pixels, then expand by the distances.
        cx = (gx + 0.5) * stride
        cy = (gy + 0.5) * stride
        boxes = np.stack(
            [
                cx - left * stride,
                cy - top * stride,
                cx + right * stride,
                cy + bottom * stride,
            ],
            axis=1,
        )

        # Undo the letterbox: remove padding, then undo the resize.
        boxes[:, 0::2] = (boxes[:, 0::2] - dx) / scale
        boxes[:, 1::2] = (boxes[:, 1::2] - dy) / scale

        kept_boxes.append(boxes)
        kept_scores.append(scores[selected])

    if not kept_boxes:
        return None

    boxes_all = np.concatenate(kept_boxes, axis=0)
    scores_all = np.concatenate(kept_scores, axis=0)
    return boxes_all[_nms(boxes_all, scores_all, NMS_THRESH)]
# =====================================================
# RKNN模型全局初始化
# =====================================================
# Shared RKNNLite instance, created once at import time and reused by every
# WoodDetectorRKNN. load_rknn() and init_runtime() signal failure through a
# non-zero return code rather than an exception, so check both explicitly;
# otherwise a missing model or unavailable runtime only shows up later as an
# opaque inference failure.
_global_rknn = RKNNLite()
if _global_rknn.load_rknn(RKNN_MODEL_PATH) != 0:
    raise RuntimeError(f"Failed to load RKNN model: {RKNN_MODEL_PATH}")
if _global_rknn.init_runtime() != 0:
    raise RuntimeError("Failed to initialise the RKNN runtime")
# =====================================================
# 木条检测类
# =====================================================
class WoodDetectorRKNN:
    """RKNN-based wood-strip detector that reports a count per image."""

    def __init__(self):
        """Attach the detector to the module-level RKNNLite instance."""
        self.rknn = _global_rknn

    def detect(self, img_np: np.ndarray) -> int:
        """Count wood strips inside the configured ROI of ``img_np``.

        The image is cropped to ``ROI`` (clamped to the image bounds),
        letterboxed to ``IMG_SIZE``, cast to float32 with a leading batch
        axis, run through the model and post-processed.

        Args:
            img_np: Original BGR image; it is passed on without any colour
                conversion.

        Returns:
            Number of boxes left after post-processing, 0 when there are none.

        Raises:
            ValueError: If ``img_np`` is ``None`` or empty, or the clamped ROI
                yields an empty crop.
            RuntimeError: If the inference call returns ``None``.
        """
        if img_np is None or img_np.size == 0:
            raise ValueError("img_np must be a non-empty image array")

        # Clamp the configured ROI to the image bounds, then crop.
        h, w = img_np.shape[:2]
        x1, y1, x2, y2 = ROI
        x1 = max(0, min(x1, w - 1))
        x2 = max(0, min(x2, w))
        y1 = max(0, min(y1, h - 1))
        y2 = max(0, min(y2, h))
        roi_img = img_np[y1:y2, x1:x2]
        if roi_img.size == 0:
            # A degenerate ROI would otherwise fail later with a division by zero.
            raise ValueError(f"ROI {ROI} yields an empty crop for a {w}x{h} image")

        # Letterbox to the model input size, remembering how to map boxes back.
        resized_img, scale, dx, dy = _letterbox_resize(roi_img, IMG_SIZE)

        # Inference on a single-image batch.
        input_tensor = np.expand_dims(resized_img.astype(np.float32), axis=0)
        outputs = self.rknn.inference([input_tensor])
        if outputs is None:
            # NOTE(review): RKNNLite appears to return None on failure instead
            # of raising - confirm against the installed toolkit version.
            raise RuntimeError("RKNN inference returned no outputs")

        boxes = _post_process(outputs, scale, dx, dy)
        return 0 if boxes is None else len(boxes)
# =====================================================
# 对外统一接口
# =====================================================
# Single detector instance shared by every detect_wood() call.
_detector = WoodDetectorRKNN()


def detect_wood(img_np: np.ndarray) -> int:
    """Public entry point: return the number of wood strips in ``img_np``.

    Delegates to the module-level ``WoodDetectorRKNN`` instance.

    Args:
        img_np: Original BGR image.

    Returns:
        Number of wood strips detected.
    """
    count = _detector.detect(img_np)
    return count
# =====================================================
# 主程序入口
# =====================================================
if __name__ == "__main__":
    # Manual smoke test: count wood strips in a fixed sample image.
    img_path = "1.jpg"
    path_missing = not os.path.exists(img_path)
    if path_missing:
        raise FileNotFoundError(f"图片不存在: {img_path}")

    # A None result from imread is treated as a load failure.
    frame = cv2.imread(img_path)
    if frame is None:
        raise ValueError("图像加载失败")

    total_count = detect_wood(frame)
    print(f"检测到木条数量: {total_count}")