Files
zjsh_code_jicheng/zhuangtai_class_cls_1980x1080/yiliao_main_rknn.py
琉璃月光 caeb0457f4 Initial commit
2025-11-18 17:16:08 +08:00

179 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import cv2
import numpy as np
import platform
from rknnlite.api import RKNNLite
# ---------------------------
# Class-id -> label mapping
# ---------------------------
# The position in this tuple is the class id, i.e. the index that
# np.argmax picks from the model output.
CLASS_NAMES = dict(enumerate((
    "未堆料",    # 0: no material piled
    "小堆料",    # 1: small pile
    "大堆料",    # 2: large pile
    "未浇筑满",  # 3: not fully poured
    "浇筑满",    # 4: fully poured
)))
# ---------------------------
# Process-wide RKNN instance (loaded once, then reused)
# ---------------------------
_global_rknn = None
# Absolute path of the model held in _global_rknn; only used to warn when a
# caller later asks for a different model.
_global_rknn_path = None
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'


# =====================================================
# RKNN MODEL
# =====================================================
def init_rknn_model(model_path):
    """Load an RKNN model on NPU core 0 once and cache it for the process.

    The first successful call loads ``model_path``. Every later call returns
    the cached instance. A later call with a *different* path still gets the
    cached model (unchanged behavior), but a warning is printed so the
    mismatch is visible instead of silent.

    Args:
        model_path: Path to the ``.rknn`` file.

    Returns:
        The initialized ``RKNNLite`` instance.

    Raises:
        RuntimeError: If ``load_rknn`` or ``init_runtime`` returns a non-zero
            status. The partially initialized instance is released first.
    """
    global _global_rknn, _global_rknn_path
    if _global_rknn is not None:
        if os.path.abspath(model_path) != _global_rknn_path:
            print(f"[WARN] RKNN model already loaded from {_global_rknn_path}; "
                  f"ignoring requested path: {model_path}")
        return _global_rknn

    rknn = RKNNLite(verbose=False)
    try:
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN failed: {ret}")
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            raise RuntimeError(f"Init runtime failed: {ret}")
    except Exception:
        # Free whatever the half-built instance holds before propagating, so
        # a failed load (or a retry loop) does not leak it.
        rknn.release()
        raise

    _global_rknn = rknn
    _global_rknn_path = os.path.abspath(model_path)
    print(f"[INFO] RKNN 模型加载成功: {model_path}")
    return rknn
# ---------------------------
# Image preprocessing (resize to 640x640, add batch axis)
# ---------------------------
def preprocess(img, size=(640, 640)):
    """Resize ``img`` to ``size`` and prepend a batch axis of length 1.

    Args:
        img: Image array (H x W or H x W x C) as produced by OpenCV.
        size: Target ``(width, height)``, in OpenCV's ``dsize`` order.

    Returns:
        A C-contiguous array of shape ``(1, height, width[, C])`` with the
        same dtype as ``img``.

    Note:
        Channel order and dtype are passed through untouched: there is no
        BGR->RGB swap and no normalization.
        NOTE(review): confirm the RKNN model was exported to expect this
        exact input layout.
    """
    width, height = size
    # In this file's pipeline the ROI has already been resized to 640x640 by
    # crop_and_resize_single, so resizing again would only copy it.
    if img.shape[1] != width or img.shape[0] != height:
        img = cv2.resize(img, (width, height))
    # cv2.resize always returned a fresh contiguous buffer; keep that
    # guarantee when the resize is skipped (e.g. a strided view was passed).
    return np.expand_dims(np.ascontiguousarray(img), 0)
# ---------------------------
# Single RKNN classification
# ---------------------------
def rknn_classify(img_resized, model_path):
    """Run one RKNN classification on an already-cropped image.

    Args:
        img_resized: Image to classify; passed through ``preprocess``.
        model_path: Model path forwarded to ``init_rknn_model``. Only the
            first load is honored because the runtime is cached.

    Returns:
        ``(class_id, pred)``: the argmax index of the first model output and
        that output flattened to a 1-D float array.

    Raises:
        RuntimeError: If the runtime returns no output, or an empty output.
    """
    rknn = init_rknn_model(model_path)
    input_tensor = preprocess(img_resized)
    outs = rknn.inference([input_tensor])
    # Guard against a missing/empty result (e.g. None) so a failed inference
    # surfaces as a clear RuntimeError rather than a TypeError on outs[0].
    if outs is None or len(outs) == 0:
        raise RuntimeError("RKNN inference returned no output")
    pred = np.asarray(outs[0]).reshape(-1)
    if pred.size == 0:
        raise RuntimeError("RKNN inference returned an empty output tensor")
    class_id = int(np.argmax(pred))
    return class_id, pred.astype(float)
# =====================================================
# ROI logic
# =====================================================
def load_single_roi(txt_path):
    """Load only the first ROI from a text file.

    Each line is expected to be ``x,y,w,h`` (integers). Blank lines are
    skipped. Only the first non-blank line is parsed; later lines are ignored.

    Args:
        txt_path: Path to the ROI text file.

    Returns:
        ``(x, y, w, h)`` as a tuple of ints.

    Raises:
        RuntimeError: If the file does not exist, contains no non-blank line,
            or its first non-blank line is not exactly four comma-separated
            integers.
    """
    if not os.path.exists(txt_path):
        raise RuntimeError(f"ROI 文件不存在: {txt_path}")
    # "utf-8-sig" also strips a leading BOM (often added by Windows editors),
    # which str.strip() keeps and int() would otherwise reject.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                x, y, w, h = map(int, s.split(','))
            except ValueError as err:
                # ValueError covers both a non-integer field and a wrong
                # number of fields (unpacking mismatch).
                raise RuntimeError(f"❌ ROI 格式错误: {s}") from err
            return (x, y, w, h)
    raise RuntimeError("❌ ROI 文件为空")
def crop_and_resize_single(img, roi, target_size=640):
    """Crop ``roi`` out of ``img`` and resize the crop to a square.

    Args:
        img: Image array whose first two axes are (height, width).
        roi: ``(x, y, w, h)`` in pixels, with top-left corner ``(x, y)``.
        target_size: Side length of the square output.

    Returns:
        The cropped region resized to ``target_size`` x ``target_size`` with
        ``cv2.INTER_AREA`` interpolation.

    Raises:
        RuntimeError: If ``w`` or ``h`` is not positive, or the ROI does not
            lie entirely inside the image.
    """
    x, y, w, h = roi
    h_img, w_img = img.shape[:2]
    # A zero/negative size would produce an empty slice that cv2.resize
    # rejects with an opaque cv2.error; report it explicitly instead.
    if w <= 0 or h <= 0:
        raise RuntimeError(f"ROI 尺寸无效: {roi}")
    if x < 0 or y < 0 or x + w > w_img or y + h > h_img:
        raise RuntimeError(f"ROI 超出图像范围: {roi}")
    roi_img = img[y:y + h, x:x + w]
    return cv2.resize(roi_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
# =====================================================
# Weighted small-vs-large pile decision (classes 1 and 2)
# =====================================================
def weighted_small_large(pred, threshold=0.4, w1=0.3, w2=0.7):
    """Choose between the small-pile and large-pile labels from classes 1 and 2.

    The two class scores are blended as ``(w1*p1 + w2*p2) / (p1 + p2)``.
    When ``p1 + p2`` is not positive, the blended score is 0.0. A blended
    score at or above ``threshold`` selects the large-pile label.

    Args:
        pred: Indexable scores; only indices 1 and 2 are read.
        threshold: Cut-off applied to the blended score.
        w1: Weight of class 1 (small pile).
        w2: Weight of class 2 (large pile).

    Returns:
        ``(label, score, p1, p2)``: the chosen label, the blended score, and
        the two raw class scores as floats.
    """
    small, large = float(pred[1]), float(pred[2])
    denom = small + large
    if denom > 0:
        blended = (w1 * small + w2 * large) / denom
    else:
        blended = 0.0
    if blended >= threshold:
        label = "大堆料"
    else:
        label = "小堆料"
    return label, blended, small, large
# =====================================================
# Classify one frame using a single ROI
# =====================================================
def classify_frame_with_single_roi(model_path, frame, roi_file, threshold=0.4):
    """Classify the single ROI of ``frame`` described by ``roi_file``.

    Args:
        model_path: Path to the RKNN model (forwarded to the cached loader).
        frame: Image as a NumPy array (documented upstream as BGR).
        roi_file: Text file whose first non-blank line is ``x,y,w,h``.
        threshold: Cut-off passed to ``weighted_small_large`` when the
            predicted class is 1 or 2.

    Returns:
        ``{"class": label, "score": s, "p1": a, "p2": b}`` with the numbers
        rounded to 4 decimals. For classes 1/2, ``score`` is the weighted
        small/large score; otherwise it is the model output of the predicted
        class. ``p1``/``p2`` are always the raw outputs at indices 1 and 2.

    Raises:
        RuntimeError: If ``frame`` is not a NumPy array (``None`` included),
            or when ROI loading, cropping, or model loading fails.
    """
    if not isinstance(frame, np.ndarray):
        raise RuntimeError("❌ classify_frame_with_single_roi 传入的 frame 无效")

    # Read the ROI, crop + resize it, then run the model on the crop.
    patch = crop_and_resize_single(frame, load_single_roi(roi_file))
    top_id, scores = rknn_classify(patch, model_path)

    if top_id in (1, 2):
        # Small vs. large pile is settled by the weighted rule, not argmax.
        label, conf, small, large = weighted_small_large(scores, threshold)
    else:
        label = CLASS_NAMES.get(top_id, f"未知类别({top_id})")
        conf = float(scores[top_id])
        small, large = float(scores[1]), float(scores[2])

    return {
        "class": label,
        "score": round(conf, 4),
        "p1": round(small, 4),
        "p2": round(large, 4),
    }
# =====================================================
# Example usage
# =====================================================
if __name__ == "__main__":
    model_path = "yiliao_cls.rknn"
    roi_file = "./roi_coordinates/1_rois.txt"
    image_path = "./test_image/1.png"
    frame = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; name the offending path instead of letting the classifier
    # report a generic "invalid frame" error.
    if frame is None:
        raise SystemExit(f"❌ 无法读取图像: {image_path}")
    result = classify_frame_with_single_roi(model_path, frame, roi_file)
    print(result)