2025-09-01 14:14:18 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
import numpy as np
|
2025-09-05 14:29:33 +08:00
|
|
|
|
from rknnlite.api import RKNNLite
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Parameters -------------------
# Minimum score the best-scoring location must reach to be reported as a
# detection; tune per model.
objectThresh: float = 0.7
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Utility functions -------------------
def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping its aspect ratio, padding the rest.

    Args:
        image: array of shape (H, W) or (H, W, C), e.g. as returned by cv2.imread.
        size: (target_width, target_height) of the output canvas.
        bg_color: fill value for the padding; a scalar or a per-channel sequence.

    Returns:
        (canvas, scale, offset_x, offset_y):
        - canvas: uint8 array of shape (target_height, target_width[, C]).
        - scale: resize factor applied to the original image.
        - offset_x, offset_y: left/top padding in pixels.
        A point (x, y) in ``image`` lands at
        (x * scale + offset_x, y * scale + offset_y) in ``canvas``.
    """
    target_width, target_height = size
    h, w = image.shape[:2]
    scale = min(target_width / w, target_height / h)

    # round() instead of int() avoids losing a pixel to float error (e.g. 359.9999).
    # max(1, ...) keeps extreme aspect ratios from requesting a zero-sized resize,
    # which cv2.resize rejects. min(...) keeps the result inside the canvas.
    new_w = min(target_width, max(1, int(round(w * scale))))
    new_h = min(target_height, max(1, int(round(h * scale))))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # The canvas keeps the input's channel layout (none for grayscale).
    channel_shape = tuple(image.shape[2:])
    canvas = np.full((target_height, target_width) + channel_shape, bg_color, dtype=np.uint8)

    offset_x, offset_y = (target_width - new_w) // 2, (target_height - new_h) // 2
    # cv2.resize drops a trailing channel axis of size 1; reshape restores it.
    canvas[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized.reshape(
        (new_h, new_w) + channel_shape
    )
    return canvas, scale, offset_x, offset_y
|
2025-09-01 14:14:18 +08:00
|
|
|
|
|
|
|
|
|
|
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), element-wise.

    Evaluated as exp(-log(1 + exp(-x))) through np.logaddexp, which never
    overflows. Large negative inputs therefore give 0 without the RuntimeWarning
    that np.exp(-x) raises. The float dtype of the input is preserved.
    """
    return np.exp(-np.logaddexp(0, -x))
|
|
|
|
|
|
|
2025-09-05 14:29:33 +08:00
|
|
|
|
# ------------------- Single-target segmentation -------------------
def _run_rknn(model_path, infer_img):
    """Run one inference of ``infer_img`` on the RKNN model at ``model_path``.

    Returns the list of output arrays, or None if loading, runtime init or
    inference fails. The RKNNLite instance is released even if inference raises.
    """
    rknn = RKNNLite(verbose=True)
    try:
        print('--> Load RKNN model')
        # RKNNLite reports failure through a non-zero return code, not an exception.
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            print(f"❌ 错误:无法加载模型 {model_path} (ret={ret})")
            return None
        print('done')

        print('--> Init runtime environment')
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            print(f"❌ 错误:运行时初始化失败 (ret={ret})")
            return None
        print('done')

        print('--> Running model')
        return rknn.inference([infer_img])
    finally:
        rknn.release()


def _pick_best_anchor(conf_list):
    """Return (best_conf, best_scale, best_pos) for the highest score over all heads.

    ``best_pos`` is a flattened *spatial* index inside head ``best_scale``. The
    class axis is reduced with max first, so the index addresses the same column
    as the per-position mask coefficients, even for multi-class models.
    best_scale/best_pos stay None if no score exceeds -1.
    """
    best_conf, best_scale, best_pos = -1, None, None
    for i, conf_map in enumerate(conf_list):
        # assumes conf_map is (1, num_classes, H, W), mirroring the coefficient
        # reshape in detect_single_mask — TODO confirm against the exported model
        per_pos = conf_map.reshape(conf_map.shape[1], -1).max(axis=0)
        pos = int(np.argmax(per_pos))
        if per_pos[pos] > best_conf:
            best_conf, best_scale, best_pos = per_pos[pos], i, pos
    return best_conf, best_scale, best_pos


def _unletterbox_mask(mask, input_hw, orig_hw, scale, offset_x, offset_y):
    """Map a prototype-resolution mask back onto the original image.

    The mask is upsampled to the network input size (the letterboxed canvas).
    The padding added by letterbox_resize is then cropped off, and the remaining
    content region is resized to the original (height, width).
    """
    in_h, in_w = input_hw
    orig_h, orig_w = orig_hw
    mask_in = cv2.resize(mask, (in_w, in_h))
    # Size of the image content inside the canvas. Clamp so rounding never reads
    # past the canvas edge and the crop is never empty.
    content_w = min(in_w - offset_x, max(1, int(round(orig_w * scale))))
    content_h = min(in_h - offset_y, max(1, int(round(orig_h * scale))))
    cropped = mask_in[offset_y:offset_y + content_h, offset_x:offset_x + content_w]
    return cv2.resize(cropped, (orig_w, orig_h))


def detect_single_mask(model_path, image_path, save_path="mask_result.png"):
    """Segment the single highest-scoring object in an image and save its mask.

    Args:
        model_path: path to the .rknn model file.
        image_path: path of the image, read with cv2.imread (BGR).
        save_path: where the binary mask PNG is written.

    Returns:
        uint8 mask with the input image's height and width (255 = object,
        0 = background). Returns None if the image cannot be read, the model
        fails to run, or the best score is below ``objectThresh``.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 错误:无法读取图像 {image_path}")
        return None

    img_resized, scale, offset_x, offset_y = letterbox_resize(img, (640, 640))
    # cvtColor gives a contiguous RGB array (a [..., ::-1] view has negative strides).
    infer_img = np.expand_dims(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB), 0)

    # ------------------- RKNN inference -------------------
    outputs = _run_rknn(model_path, infer_img)
    if outputs is None or len(outputs) < 13:
        print("❌ 错误:模型推理失败或输出数量不符合预期")
        return None

    # ------------------- Output handling -------------------
    # Output indices follow the layout the original author supplied (13 outputs).
    # outputs[1/5/9] are per-head object confidences, outputs[3/7/11] per-head mask
    # coefficients, and outputs[12] the prototypes ([32, 160, 160] per the original
    # note). Only the single highest-scoring location is turned into a mask.
    # NOTE(review): the other outputs are ignored. If they carry box predictions,
    # the mask is not cropped to the detected box, so neighbouring objects may leak in.
    mask_coeffs_list = [outputs[3], outputs[7], outputs[11]]
    conf_list = [outputs[1], outputs[5], outputs[9]]
    proto = outputs[12][0]

    best_conf, best_scale, best_pos = _pick_best_anchor(conf_list)
    if best_conf < objectThresh:
        print(f"⚠️ 置信度低于阈值 {objectThresh},未检测到目标")
        return None

    # ------------------- Build the mask -------------------
    coeffs = mask_coeffs_list[best_scale]
    coeff = coeffs.reshape(coeffs.shape[1], -1)[:, best_pos]
    mask_flat = coeff @ proto.reshape(proto.shape[0], -1)
    mask = sigmoid(mask_flat).reshape(proto.shape[1], proto.shape[2])

    # The prototypes cover the letterboxed input, so the padding is removed before
    # resizing back to the original image. Otherwise the mask is misaligned.
    mask_resized = _unletterbox_mask(
        mask, img_resized.shape[:2], img.shape[:2], scale, offset_x, offset_y
    )
    mask_bin = (mask_resized > 0.5).astype(np.uint8) * 255

    # ------------------- Save -------------------
    if cv2.imwrite(save_path, mask_bin):
        print(f"✅ 单目标 mask 已保存: {save_path}")
    else:
        print(f"❌ 错误:无法保存 mask 到 {save_path}")

    return mask_bin
|
2025-09-01 14:14:18 +08:00
|
|
|
|
|
2025-09-05 14:29:33 +08:00
|
|
|
|
# ------------------- Example usage -------------------
if __name__ == "__main__":
    # Demo inputs, resolved relative to the current working directory.
    model_path, image_path = "seg.rknn", "2.jpg"
    detect_single_mask(model_path, image_path)
|