Files
wood_vis/wood_ng/wood_ng.py
2026-02-27 15:08:35 +08:00

208 lines
5.2 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
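"""
ROI image classification module (RKNN).

Classifies a fixed ROI region of the input image with RKNNLite and
returns the class ID (0: abnormal, 1: normal).

Features:
- Singleton loading of the RKNN model
- ROI cropping and resizing
- BGR -> RGB preprocessing
- Script entry point for a quick manual test
"""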
import os
from typing import Dict,Union
import cv2
import numpy as np
from rknnlite.api import RKNNLite
# =====================================================
# Global configuration (constants)
# =====================================================

# File name of the RKNN classification model to load.
RKNN_MODEL_PATH: str = "wood_ng_cls.rknn"

# Region of interest in pixel coordinates, ordered (x1, y1, x2, y2).
ROI: tuple[int, int, int, int] = (
    3,    # x1
    0,    # y1
    694,  # x2
    182,  # y2
)

# Display label for each class ID produced by the classifier.
CLASS_NAMES: Dict[int, str] = {
    0: "异常",  # abnormal
    1: "正常",  # normal
}
# =====================================================
# Global RKNN instance (singleton)
# =====================================================
# Holds the one RKNNLite instance shared by the whole process; stays None
# until the first successful call to _init_rknn_model().
_global_rknn: Union[RKNNLite, None] = None


def _init_rknn_model(model_path: str) -> RKNNLite:
    """Load the RKNN model once and return the shared RKNNLite instance.

    The first successful call loads ``model_path``, initialises the runtime
    with ``core_mask=RKNNLite.NPU_CORE_0`` and caches the instance in the
    module-level ``_global_rknn``. Every later call returns that cached
    instance unchanged -- ``model_path`` is ignored once a model is loaded.

    Args:
        model_path: Path to the ``.rknn`` model file.

    Returns:
        The initialised (possibly cached) RKNNLite instance.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        RuntimeError: ``load_rknn`` or ``init_runtime`` returned a non-zero
            status code.
    """
    global _global_rknn

    if _global_rknn is not None:
        return _global_rknn

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"RKNN 模型不存在: {model_path}")

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Do not leak the half-initialised handle: free whatever the failed
        # instance acquired before propagating the error. The global stays
        # None, so a later call can retry from scratch.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# =====================================================
# Preprocessing
# =====================================================
def _preprocess_input(img: np.ndarray) -> np.ndarray:
    """Turn a resized ROI image into the model's input tensor.

    The image is converted from BGR to RGB, cast to float32 (pixel values
    are not rescaled) and given a leading batch axis.

    Args:
        img: BGR image of shape (640, 640, 3).

    Returns:
        C-contiguous float32 array in NHWC layout; shape (1, 640, 640, 3)
        for a 3-channel input.

    Raises:
        ValueError: The image is not 640x640, or it is not a colour image
            (for example a 2-D grayscale array).
    """
    if img.shape[:2] != (640, 640):
        raise ValueError("输入图像必须是 640x640")

    # Without this check a grayscale or single-channel array passes the size
    # test above and only fails inside cv2.cvtColor with an opaque OpenCV
    # error instead of the documented ValueError.
    # NOTE(review): 4-channel input is let through on the assumption that
    # COLOR_BGR2RGB also accepts BGRA -- confirm against the OpenCV docs.
    if img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError(
            f"Input image must be a 3-channel BGR image, got shape {img.shape}"
        )

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_tensor = np.expand_dims(img_rgb.astype(np.float32), axis=0)
    return np.ascontiguousarray(input_tensor)
# =====================================================
# ROI + RKNN classifier
# =====================================================
class ROIClassifierRKNN:
    """Classify one rectangular region of an image with an RKNN model.

    Responsibilities:
    - obtain the RKNN model through ``_init_rknn_model``
    - crop the configured ROI out of the input image
    - run inference and return the winning class ID
    """

    def __init__(
        self,
        model_path: str,
        roi: Union[tuple[int, int, int, int], None] = None,
    ) -> None:
        """Create the classifier.

        Args:
            model_path: Path to the RKNN model file.
            roi: Optional ``(x1, y1, x2, y2)`` pixel rectangle to classify.
                When omitted, the module-level ``ROI`` constant is used.
        """
        self.rknn = _init_rknn_model(model_path)
        self.roi = ROI if roi is None else tuple(roi)

    @staticmethod
    def _clamp_roi(
        roi: tuple[int, int, int, int], width: int, height: int
    ) -> tuple[int, int, int, int]:
        """Clamp ``roi`` to a ``width`` x ``height`` image and validate it.

        Returns:
            The clamped ``(x1, y1, x2, y2)`` rectangle.

        Raises:
            ValueError: The clamped rectangle has no area.
        """
        x1, y1, x2, y2 = roi
        # Start coordinates must index an existing pixel (<= size - 1);
        # end coordinates are slice bounds, so they may equal the size.
        x1 = max(0, min(x1, width - 1))
        x2 = max(0, min(x2, width))
        y1 = max(0, min(y1, height - 1))
        y2 = max(0, min(y2, height))
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"ROI 坐标无效: {(x1, y1, x2, y2)}")
        return x1, y1, x2, y2

    def classify(self, img_np: np.ndarray) -> int:
        """Classify the ROI of ``img_np``.

        Args:
            img_np: Original BGR image.

        Returns:
            Index of the largest value in the model's first output
            (0: abnormal, 1: normal).

        Raises:
            ValueError: ``img_np`` is not an image array, or the ROI is
                empty after clamping to the image bounds.
            RuntimeError: Inference returned no outputs.
        """
        # cv2.imread() returns None for an unreadable file; fail with a clear
        # message instead of an AttributeError on ``.shape``.
        if img_np is None or getattr(img_np, "ndim", 0) < 2:
            raise ValueError("Input image must be a numpy array with at least 2 dimensions")

        height, width = img_np.shape[:2]
        x1, y1, x2, y2 = self._clamp_roi(self.roi, width, height)

        # 1. Crop the ROI.
        roi_img = img_np[y1:y2, x1:x2]
        # 2. Resize to the 640x640 size required by _preprocess_input.
        roi_img = cv2.resize(roi_img, (640, 640))
        # 3. Build the input tensor.
        input_tensor = _preprocess_input(roi_img)
        # 4. Run inference.
        outputs = self.rknn.inference([input_tensor])
        # NOTE(review): RKNNLite.inference is understood to return None when
        # the run fails -- confirm against the toolkit documentation.
        if outputs is None or len(outputs) == 0:
            raise RuntimeError("RKNN inference failed: no outputs returned")

        logits = outputs[0].reshape(-1).astype(float)
        return int(np.argmax(logits))
# =====================================================
# Public interface
# =====================================================
# Built once at import time, so importing this module already loads the
# RKNN model from RKNN_MODEL_PATH.
_classifier = ROIClassifierRKNN(RKNN_MODEL_PATH)


def classify_wood_ng(img_np: np.ndarray) -> int:
    """Classify whether the imaged wood strip is NG (defective) material.

    Delegates to the module-level ``_classifier`` instance.

    Args:
        img_np: Original BGR image.

    Returns:
        The class ID reported by the classifier (0 or 1).
    """
    class_id = _classifier.classify(img_np)
    return class_id
# =====================================================
# Manual test entry point
# =====================================================
if __name__ == "__main__":
    # Classify a sample image from the working directory and print the result.
    sample_path = "1.png"
    if not os.path.exists(sample_path):
        raise FileNotFoundError(f"图片不存在: {sample_path}")

    frame = cv2.imread(sample_path)
    # cv2.imread signals a decode failure by returning None, not by raising.
    if frame is None:
        raise ValueError("图像加载失败")

    class_id = classify_wood_ng(frame)
    class_label = CLASS_NAMES[class_id]
    print(f"NG料结果{class_id} " f"({class_label})")