Files
Feeding_control_system/vision/overflow_model/yiliao_main_rknn.py
2025-11-17 00:05:40 +08:00

186 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
import platform
# ---------------------------
# Class id -> label mapping
# ---------------------------
CLASS_NAMES = {
    0: "未堆料",    # no pile-up
    1: "小堆料",    # small pile-up
    2: "大堆料",    # large pile-up
    3: "未浇筑满",  # not fully poured
    4: "浇筑满",    # fully poured
}

# ---------------------------
# Global RKNN instance (loaded only once)
# ---------------------------
_global_rknn = None

DEVICE_COMPATIBLE_NODE = "/proc/device-tree/compatible"
# =====================================================
# RKNN MODEL
# =====================================================
def init_rknn_model(model_path):
    """Return the shared RKNNLite instance, loading the model on first use.

    The initialised instance is cached in the module-level ``_global_rknn``.
    Once a model has been loaded, later calls return the cached instance and
    ``model_path`` is ignored.

    Args:
        model_path: Path of the RKNN model file loaded on the first call.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If ``load_rknn`` or ``init_runtime`` returns a
            non-zero status code.
    """
    global _global_rknn
    if _global_rknn is not None:
        return _global_rknn

    # Imported lazily, and only when a model really has to be loaded, so the
    # cached path does not depend on the rknnlite package being importable.
    from rknnlite.api import RKNNLite

    rknn = RKNNLite(verbose=False)

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        # Do not leak the half-initialised handle on failure.
        rknn.release()
        raise RuntimeError(f"Load RKNN failed: {ret}")

    ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        rknn.release()
        raise RuntimeError(f"Init runtime failed: {ret}")

    # Publish the instance only after both steps succeeded, so a failed
    # attempt can be retried by a later call.
    _global_rknn = rknn
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# ---------------------------
# Image preprocessing (resize to a uniform 640×640)
# ---------------------------
def preprocess(img, size=(640, 640)):
    """Resize ``img`` to ``size`` and prepend a length-1 leading axis.

    Args:
        img: Image array accepted by ``cv2.resize``.
        size: Target size passed straight to ``cv2.resize``.

    Returns:
        The resized image with one extra axis inserted at position 0.
    """
    resized = cv2.resize(img, size)
    # Insert the leading axis at position 0.
    return resized[np.newaxis, ...]
# ---------------------------
# 单次 RKNN 分类
# ---------------------------
def rknn_classify(img_resized, model_path):
    """Run one classification pass on a single image.

    Args:
        img_resized: Image array; it is passed through ``preprocess`` before
            inference.
        model_path: RKNN model path forwarded to ``init_rknn_model``.

    Returns:
        A ``(class_id, pred)`` tuple. ``pred`` is the first model output
        flattened to 1-D and converted to float; ``class_id`` is the index of
        its largest value.

    Raises:
        RuntimeError: If the inference call yields no output.
    """
    rknn = init_rknn_model(model_path)
    input_tensor = preprocess(img_resized)
    outs = rknn.inference([input_tensor])

    # Fail with a clear message instead of an opaque TypeError/IndexError on
    # ``outs[0]`` when the runtime produced nothing.
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")

    pred = outs[0].reshape(-1)
    class_id = int(np.argmax(pred))
    return class_id, pred.astype(float)
# =====================================================
# ROI 逻辑
# =====================================================
def load_rois(txt_path):
    """Read ROI rectangles from a text file with one ``x,y,w,h`` per line.

    Blank lines are ignored. A line that does not consist of exactly four
    comma-separated integers is reported on stdout and skipped.

    Args:
        txt_path: Path of the ROI text file.

    Returns:
        A list of ``(x, y, w, h)`` integer tuples in file order. The list is
        empty when the file does not exist (a message is printed) or holds
        no valid line.
    """
    if not os.path.exists(txt_path):
        print(f"❌ ROI 文件不存在: {txt_path}")
        return []

    rois = []
    with open(txt_path) as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError:
                # ValueError covers both a non-integer field and a wrong
                # field count (unpacking error). Anything else, such as
                # KeyboardInterrupt, must propagate rather than be reported
                # as a format error.
                print("ROI 格式错误:", s)
                continue
            rois.append((x, y, w, h))
    return rois
def crop_and_resize(img, rois, target_size=640):
    """Crop every valid ROI out of ``img`` and resize it to a square.

    An ROI is skipped when its width or height is not positive, or when it
    does not lie completely inside the image.

    Args:
        img: Image array whose first two dimensions are read as
            (height, width).
        rois: Iterable of ``(x, y, w, h)`` rectangles.
        target_size: Edge length of the square output crops.

    Returns:
        A list of ``(crop, idx)`` tuples, where ``idx`` is the position of
        the ROI in ``rois``. Skipped ROIs leave gaps in the index sequence.
    """
    h_img, w_img = img.shape[:2]
    crops = []
    for idx, (x, y, w, h) in enumerate(rois):
        # A zero-sized ROI yields an empty crop that cv2.resize rejects, and
        # a negative size can make the slice wrap around to the wrong region.
        if w <= 0 or h <= 0:
            continue
        if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
            continue
        roi = img[y:y + h, x:x + w]
        roi_resized = cv2.resize(
            roi, (target_size, target_size), interpolation=cv2.INTER_AREA
        )
        crops.append((roi_resized, idx))
    return crops
# =====================================================
# class1/class2 加权分类增强
# =====================================================
def weighted_small_large(pred, threshold=0.4, w1=0.3, w2=0.7):
    """Decide between the small-pile and large-pile labels.

    The score is the weighted sum of ``pred[1]`` and ``pred[2]`` divided by
    their plain sum; it is 0.0 when that sum is not positive.

    Args:
        pred: Indexable sequence of class scores; only indices 1 and 2 are
            read.
        threshold: Minimum score for the large-pile label.
        w1: Weight applied to ``pred[1]``.
        w2: Weight applied to ``pred[2]``.

    Returns:
        ``(label, score, p1, p2)``, where ``p1`` and ``p2`` are ``pred[1]``
        and ``pred[2]`` as floats.
    """
    prob_small, prob_large = float(pred[1]), float(pred[2])
    prob_sum = prob_small + prob_large

    if prob_sum > 0:
        score = (w1 * prob_small + w2 * prob_large) / prob_sum
    else:
        # Nothing to normalise by; fall back to the lowest score.
        score = 0.0

    if score >= threshold:
        label = "大堆料"
    else:
        label = "小堆料"
    return label, score, prob_small, prob_large
# =====================================================
# ⭐ 高复用:一行完成 ROI + 推理 ⭐
# =====================================================
def classify_frame_with_rois(model_path, frame, roi_file, threshold=0.4):
    """Classify every ROI of one frame in a single call.

    ROIs are loaded from ``roi_file``, cropped out of ``frame``, and each
    crop is classified with the RKNN model. When the predicted class is 1 or
    2, the label and score come from ``weighted_small_large``. Otherwise the
    label is looked up in ``CLASS_NAMES`` and the score is the model output
    for the predicted class.

    Args:
        model_path: Path of the RKNN model.
        frame: Image as a numpy array.
        roi_file: Text file holding the ROI rectangles.
        threshold: Small/large pile decision threshold for classes 1 and 2.

    Returns:
        A list with one dict per classified ROI, for example::

            {"roi": idx, "class": label, "score": 0.93, "p1": 0.22, "p2": 0.71}

        ``score``, ``p1`` and ``p2`` are rounded to four decimals.

    Raises:
        RuntimeError: If ``frame`` is not a numpy array, or no ROI could be
            loaded from ``roi_file``.
    """
    # None is not an ndarray, so this single check also rejects a missing frame.
    if not isinstance(frame, np.ndarray):
        raise RuntimeError("❌ classify_frame_with_rois 传入的 frame 无效")

    rois = load_rois(roi_file)
    if not rois:
        raise RuntimeError("ROI 文件为空")

    results = []
    for roi_img, idx in crop_and_resize(frame, rois):
        class_id, pred = rknn_classify(roi_img, model_path)

        if class_id in (1, 2):
            # Small vs. large pile is re-decided from the weighted score.
            label, score, p1, p2 = weighted_small_large(pred, threshold)
        else:
            label = CLASS_NAMES.get(class_id, f"未知类别({class_id})")
            score = float(pred[class_id])
            p1 = float(pred[1])
            p2 = float(pred[2])

        entry = {
            "roi": idx,
            "class": label,
            "score": round(score, 4),
            "p1": round(p1, 4),
            "p2": round(p2, 4),
        }
        results.append(entry)

    return results
# =====================================================
# 示例调用
# =====================================================
if __name__ == "__main__":
    # The model file is expected next to this script; the ROI list and the
    # test image are given relative to the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "yiliao_cls.rknn")
    roi_file = "./roi_coordinates/1_rois.txt"

    frame = cv2.imread("./test_image/2.jpg")

    for res in classify_frame_with_rois(model_path, frame, roi_file):
        print(res)