208 lines
5.2 KiB
Python
208 lines
5.2 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
ROI RKNN 图像分类模块
|
|||
|
|
|
|||
|
|
基于 RKNNLite 对输入图像的指定 ROI 区域进行分类,
|
|||
|
|
输出类别 ID(0: 异常,1: 正常)。
|
|||
|
|
|
|||
|
|
支持:
|
|||
|
|
- RKNN 模型单例加载
|
|||
|
|
- ROI 裁剪与缩放
|
|||
|
|
- BGR → RGB 预处理
|
|||
|
|
- 主程序测试入口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
from typing import Dict
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from rknnlite.api import RKNNLite
|
|||
|
|
|
|||
|
|
# =====================================================
# Global configuration (constants)
# =====================================================

# Model file; a relative path is resolved against the current working directory.
RKNN_MODEL_PATH: str = "wood_ng_cls.rknn"

# Region of interest as (x1, y1, x2, y2) pixel coordinates.
# x2 / y2 are used as exclusive slice bounds when cropping.
ROI: tuple[int, int, int, int] = (3, 0, 694, 182)

# Maps the model's output class ID to a display label (0: abnormal, 1: normal).
CLASS_NAMES: Dict[int, str] = {0: "异常", 1: "正常"}
|
|||
|
|
|
|||
|
|
# =====================================================
# Global RKNN instance (process-wide singleton)
# =====================================================

# Populated by the first successful _init_rknn_model() call and reused afterwards.
_global_rknn: RKNNLite | None = None


def _init_rknn_model(model_path: str) -> RKNNLite:
    """Load the RKNN model once and return the shared RKNNLite runtime.

    The first successful call loads ``model_path`` and initializes the runtime
    on NPU core 0; every later call returns the cached instance.

    Note:
        The cache is not keyed by path: once a model has been loaded, later
        calls return it even if a different ``model_path`` is passed.

    Args:
        model_path: Path to the ``.rknn`` model file.

    Returns:
        The initialized RKNNLite instance.

    Raises:
        FileNotFoundError: ``model_path`` is not an existing regular file.
        RuntimeError: Loading the model or initializing the runtime failed.
    """
    global _global_rknn

    if _global_rknn is not None:
        return _global_rknn

    # isfile (not exists) so a directory is rejected with a clear message.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"RKNN 模型不存在: {model_path}")

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")

        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Free the native runtime handle before propagating; otherwise a failed
        # load leaks it and a retry would allocate yet another one.
        rknn.release()
        raise

    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")

    return rknn
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =====================================================
# Preprocessing
# =====================================================

def _preprocess_input(img: np.ndarray) -> np.ndarray:
    """Convert a 640x640 BGR ROI image into the model's input tensor.

    Args:
        img: BGR image of shape ``(640, 640, 3)`` (a 4-channel BGRA image is
            also accepted; ``COLOR_BGR2RGB`` drops the extra channel).

    Returns:
        Contiguous float32 NHWC tensor of shape ``(1, 640, 640, 3)`` in RGB order.

    Raises:
        ValueError: The image is not 640x640, or is not a 3-D array with 3 or
            4 channels.
    """
    if img.shape[:2] != (640, 640):
        raise ValueError("输入图像必须是 640x640")

    # Reject grayscale / odd channel counts here instead of letting
    # cv2.cvtColor fail later with an opaque cv2.error.
    if img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError("输入图像必须是 3/4 通道 BGR 图像")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Add the batch dimension: (H, W, C) -> (1, H, W, C).
    input_tensor = np.expand_dims(img_rgb.astype(np.float32), axis=0)

    return np.ascontiguousarray(input_tensor)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =====================================================
|
|||
|
|
# ROI + RKNN 分类器
|
|||
|
|
# =====================================================
|
|||
|
|
|
|||
|
|
class ROIClassifierRKNN:
|
|||
|
|
"""
|
|||
|
|
基于 RKNN 的 ROI 区域分类器。
|
|||
|
|
|
|||
|
|
功能:
|
|||
|
|
- 加载 RKNN 分类模型
|
|||
|
|
- 从原始图像中裁剪 ROI
|
|||
|
|
- 执行分类推理并返回类别 ID
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, model_path: str) -> None:
|
|||
|
|
"""
|
|||
|
|
初始化分类器。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
model_path (str): RKNN 模型路径
|
|||
|
|
"""
|
|||
|
|
self.rknn = _init_rknn_model(model_path)
|
|||
|
|
|
|||
|
|
def classify(self, img_np: np.ndarray) -> int:
|
|||
|
|
"""
|
|||
|
|
对输入图像进行 ROI 分类。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
img_np (np.ndarray): 原始 BGR 图像
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
int: 分类结果(0: 异常,1: 正常)
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
ValueError: ROI 坐标非法
|
|||
|
|
"""
|
|||
|
|
height, width = img_np.shape[:2]
|
|||
|
|
x1, y1, x2, y2 = ROI
|
|||
|
|
|
|||
|
|
# -------- ROI 边界保护 --------
|
|||
|
|
x1 = max(0, min(x1, width - 1))
|
|||
|
|
x2 = max(0, min(x2, width))
|
|||
|
|
y1 = max(0, min(y1, height - 1))
|
|||
|
|
y2 = max(0, min(y2, height))
|
|||
|
|
|
|||
|
|
if x2 <= x1 or y2 <= y1:
|
|||
|
|
raise ValueError(f"ROI 坐标无效: {(x1, y1, x2, y2)}")
|
|||
|
|
|
|||
|
|
# -------- 1. 裁剪 ROI --------
|
|||
|
|
roi_img = img_np[y1:y2, x1:x2]
|
|||
|
|
|
|||
|
|
# -------- 2. resize 到 640×640 --------
|
|||
|
|
roi_img = cv2.resize(roi_img, (640, 640))
|
|||
|
|
|
|||
|
|
# -------- 3. 预处理 --------
|
|||
|
|
input_tensor = _preprocess_input(roi_img)
|
|||
|
|
|
|||
|
|
# -------- 4. RKNN 推理 --------
|
|||
|
|
outputs = self.rknn.inference([input_tensor])
|
|||
|
|
logits = outputs[0].reshape(-1).astype(float)
|
|||
|
|
|
|||
|
|
return int(np.argmax(logits))
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =====================================================
# Public API
# =====================================================

# Constructed once at import time; every call below reuses this instance.
_classifier = ROIClassifierRKNN(model_path=RKNN_MODEL_PATH)


def classify_wood_ng(img_np: np.ndarray) -> int:
    """Classify whether a wood-strip image is an NG (reject) part.

    Args:
        img_np: Raw BGR image.

    Returns:
        Class ID: 0 for abnormal, 1 for normal.
    """
    classifier = _classifier
    result = classifier.classify(img_np)
    return result
|
|||
|
|
|
|||
|
|
# =====================================================
# Manual test entry point
# =====================================================

if __name__ == "__main__":
    import sys

    # Optional image path on the command line; defaults to the original sample.
    img_path = sys.argv[1] if len(sys.argv) > 1 else "1.png"

    # isfile (not exists) so a directory path is reported clearly.
    if not os.path.isfile(img_path):
        raise FileNotFoundError(f"图片不存在: {img_path}")

    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("图像加载失败")

    result = classify_wood_ng(img)
    # .get() avoids a KeyError if the model ever emits an unmapped class ID.
    label = CLASS_NAMES.get(result, "未知")
    print(f"NG料结果:{result} ({label})")
|