Files
2025-09-05 14:29:33 +08:00

94 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
# ------------------- Parameters -------------------
objectThresh = 0.7 # confidence threshold below which no object is reported; tune as needed
# ------------------- Utility functions -------------------
def letterbox_resize(image, size, bg_color=114):
    """Resize an image to fit inside ``size`` keeping its aspect ratio, padding the rest.

    The resized image is centred on a 3-channel ``uint8`` canvas filled with
    ``bg_color``.

    Args:
        image: Image array whose first two axes are (height, width). It is
            pasted onto a 3-channel canvas, so a 3-channel image is expected.
        size: Target ``(width, height)`` of the returned canvas.
        bg_color: Fill value used for the padding border.

    Returns:
        Tuple ``(canvas, scale, offset_x, offset_y)`` where ``scale`` is the
        resize factor applied to the image and the offsets are the top-left
        position of the resized image inside the canvas.
    """
    target_width, target_height = size
    h, w = image.shape[:2]
    scale = min(target_width / w, target_height / h)
    # int() truncates, so a very elongated image could end up with a 0-pixel
    # side, which cv2.resize rejects; keep at least one pixel per axis.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # np.full pins the canvas dtype to uint8 regardless of the type of bg_color.
    canvas = np.full((target_height, target_width, 3), bg_color, dtype=np.uint8)
    offset_x = (target_width - new_w) // 2
    offset_y = (target_height - new_h) // 2
    canvas[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized
    return canvas, scale, offset_x, offset_y
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``, element-wise.

    Computed from ``exp(-|x|)`` so the exponential never overflows: the naive
    form emits an overflow RuntimeWarning for large negative inputs.

    Args:
        x: A NumPy array or numeric scalar.

    Returns:
        An array shaped like ``x`` (a NumPy scalar for scalar input) with
        values in ``[0, 1]``.
    """
    z = np.exp(-np.abs(x))  # always in (0, 1], cannot overflow
    # For x >= 0 this is the textbook form; for x < 0 the algebraically equal
    # exp(x) / (1 + exp(x)) is used instead.
    result = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    # Indexing with () turns a 0-d array into a scalar and leaves n-d arrays as-is.
    return result[()]
# ------------------- Single-object segmentation -------------------
def _run_rknn_inference(model_path, infer_img):
    """Load the RKNN model, run one inference on NPU core 0 and release the runtime.

    Args:
        model_path: Path of the ``.rknn`` model file.
        infer_img: Input array handed to ``RKNNLite.inference``.

    Returns:
        The value returned by ``RKNNLite.inference``, or ``None`` when loading
        the model or initialising the runtime reports a non-zero status.
    """
    rknn = RKNNLite(verbose=True)
    try:
        print('--> Load RKNN model')
        # load_rknn / init_runtime signal failure through a non-zero return code.
        if rknn.load_rknn(model_path) != 0:
            print(f"❌ 错误:加载 RKNN 模型失败 {model_path}")
            return None
        print('done')
        print('--> Init runtime environment')
        if rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0) != 0:
            print("❌ 错误:初始化 RKNN 运行时环境失败")
            return None
        print('done')
        print('--> Running model')
        return rknn.inference([infer_img])
    finally:
        # Always free the NPU runtime, even when a step above fails or raises.
        rknn.release()


def detect_single_mask(model_path, image_path):
    """Segment the single most confident object in an image with an RKNN model.

    The image is letterboxed to 640x640, run through the model, and the mask
    coefficients at the highest-confidence location (across the three output
    scales) are combined with the prototype masks. The resulting mask is mapped
    back to the original image geometry, binarised at 0.5, written to
    ``mask_result.png`` and returned.

    Args:
        model_path: Path of the ``.rknn`` model file.
        image_path: Path of the image to segment.

    Returns:
        A ``uint8`` mask with the original image's height and width holding
        0 or 255, or ``None`` when the image cannot be read, inference fails,
        or the best confidence is below ``objectThresh``.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 错误:无法读取图像 {image_path}")
        return None

    input_w, input_h = 640, 640
    img_resized, scale, offset_x, offset_y = letterbox_resize(img, (input_w, input_h))
    # BGR -> RGB. The reversed view is non-contiguous, so copy it into a
    # contiguous buffer before handing it to the runtime.
    infer_img = np.ascontiguousarray(img_resized[..., ::-1])
    infer_img = np.expand_dims(infer_img, 0)

    # ------------------- RKNN inference -------------------
    outputs = _run_rknn_inference(model_path, infer_img)
    if outputs is None or len(outputs) < 13:
        print("❌ 错误:RKNN 推理失败或输出数量不足")
        return None

    # ------------------- Decode outputs -------------------
    # The model is expected to emit 13 outputs; the indices below follow the
    # output order of the exported model (NOTE(review): confirm against the
    # actual model if it is re-exported).
    mask_coeffs_list = [outputs[3], outputs[7], outputs[11]]  # mask coefficients per scale
    conf_list = [outputs[1], outputs[5], outputs[9]]  # object confidence per scale
    proto = outputs[12][0]  # prototype masks, expected as [32, 160, 160]

    # Find the single highest-confidence position over all scales.
    best_idx = None
    best_conf = -1
    best_scale = None
    for i, conf_map in enumerate(conf_list):
        conf_map_flat = conf_map.flatten()
        idx = int(np.argmax(conf_map_flat))
        if conf_map_flat[idx] > best_conf:
            best_conf = conf_map_flat[idx]
            best_idx = idx
            best_scale = i
    if best_conf < objectThresh:
        print(f"⚠️ 置信度低于阈值 {objectThresh},未检测到目标")
        return None

    # ------------------- Build the mask -------------------
    coeff_map = mask_coeffs_list[best_scale]
    # (num_coeffs, H*W) -- assumes a batch of one in channel-first layout.
    coeff = coeff_map.reshape(coeff_map.shape[1], -1)
    # If the confidence map carries several channels, the flat argmax index also
    # encodes the channel; fold it back onto the spatial cell. With a single
    # channel this leaves the index unchanged.
    cell = best_idx % coeff.shape[1]
    mask_flat = np.matmul(coeff[:, cell], proto.reshape(proto.shape[0], -1))
    mask = sigmoid(mask_flat).reshape(proto.shape[1], proto.shape[2])
    mask = mask.astype(np.float32)  # cv2.resize does not accept every float dtype

    # ------------------- Map back to the original image -------------------
    # The mask covers the letterboxed network input, padding included. Scale it
    # to the input resolution, cut the padding away, and only then stretch it to
    # the original size; resizing directly would distort non-square images.
    orig_h, orig_w = img.shape[:2]
    content_w = max(1, int(orig_w * scale))
    content_h = max(1, int(orig_h * scale))
    mask_input = cv2.resize(mask, (input_w, input_h))
    mask_content = np.ascontiguousarray(
        mask_input[offset_y:offset_y + content_h, offset_x:offset_x + content_w]
    )
    mask_resized = cv2.resize(mask_content, (orig_w, orig_h))
    mask_bin = (mask_resized > 0.5).astype(np.uint8) * 255

    # ------------------- Save -------------------
    # cv2.imwrite reports failure through its return value instead of raising.
    if cv2.imwrite("mask_result.png", mask_bin):
        print("✅ 单目标 mask 已保存: mask_result.png")
    else:
        print("❌ 错误:无法保存 mask_result.png")
    return mask_bin
# ------------------- Usage example -------------------
if __name__ == "__main__":
    # Example run: segment "2.jpg" with the "seg.rknn" model from the working directory.
    model_path, image_path = "seg.rknn", "2.jpg"
    detect_single_mask(model_path, image_path)