Files
zjsh_yolov11/angle_base_seg/bushu.py

94 lines
3.3 KiB
Python
Raw Normal View History

2025-09-01 14:14:18 +08:00
import os
import cv2
import numpy as np
2025-09-05 14:29:33 +08:00
from rknnlite.api import RKNNLite
# ------------------- Parameters -------------------
# Confidence a detection must reach to be accepted (detections scoring
# below this are rejected); tune per model/dataset.
objectThresh: float = 0.7
# ------------------- Helper functions -------------------
def letterbox_resize(image, size, bg_color=114):
    """Fit *image* inside *size* while keeping its aspect ratio, padding the rest.

    The image is scaled uniformly so it fits the target, then pasted centered
    on a canvas filled with *bg_color*.

    Args:
        image: H x W x C image array (e.g. as returned by ``cv2.imread``).
        size: ``(target_width, target_height)`` of the output canvas.
        bg_color: padding value; a scalar or a per-channel sequence.

    Returns:
        ``(canvas, scale, offset_x, offset_y)``:
        - canvas: ``target_height x target_width x C`` uint8 array;
        - scale: resize factor applied to the image;
        - offset_x / offset_y: left / top padding in pixels.
    """
    target_width, target_height = size
    h, w, channels = image.shape
    scale = min(target_width / w, target_height / h)
    # Truncate as before, but never collapse a side to zero pixels:
    # cv2.resize rejects a zero-sized destination.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # np.full keeps the canvas uint8 even for a per-channel bg_color tuple,
    # whereas np.ones(...) * tuple silently promotes to a wider integer dtype.
    canvas = np.full((target_height, target_width, channels), bg_color, dtype=np.uint8)
    offset_x = (target_width - new_w) // 2
    offset_y = (target_height - new_h) // 2
    # cv2.resize drops a trailing channel axis of size 1; restore it so the
    # slice assignment also works for single-channel (H, W, 1) inputs.
    canvas[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized.reshape(
        new_h, new_w, channels
    )
    return canvas, scale, offset_x, offset_y
2025-09-01 14:14:18 +08:00
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Accepts Python scalars and NumPy arrays. For very negative inputs
    ``exp(-x)`` overflows to ``inf``, and the expression still yields the
    correct limit 0. The floating-point overflow warning is therefore
    suppressed instead of being printed for every large negative logit.
    """
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
2025-09-05 14:29:33 +08:00
# ------------------- Single-object segmentation -------------------
def _run_rknn(model_path, infer_img):
    """Run one inference with an RKNN model; the runtime is always released.

    Returns the list returned by ``RKNNLite.inference``. Returns None if
    loading the model, initialising the runtime, or inference fails.
    """
    rknn = RKNNLite(verbose=True)
    try:
        print('--> Load RKNN model')
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            print(f"❌ 错误:加载 RKNN 模型失败 (ret={ret})")
            return None
        print('done')
        print('--> Init runtime environment')
        ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            print(f"❌ 错误:初始化运行环境失败 (ret={ret})")
            return None
        print('done')
        print('--> Running model')
        outputs = rknn.inference([infer_img])
        if outputs is None:
            print("❌ 错误:推理失败")
        return outputs
    finally:
        # Free NPU resources even if a step above fails or raises.
        rknn.release()


def _select_best_cell(conf_list):
    """Find the highest-confidence grid cell over all detection heads.

    Each confidence map is assumed to be (1, num_classes, H, W) — NOTE(review):
    confirm against the exported model. The class axis is max-reduced first,
    so the returned index is a flat H*W cell index. That is the index space
    the mask coefficients are reshaped into. Ties keep the first cell/head.

    Returns:
        ``(best_conf, best_scale, best_idx)``: the score, the head index, and
        the flat cell index within that head.
    """
    best_conf, best_scale, best_idx = -1.0, None, None
    for head, conf_map in enumerate(conf_list):
        cell_conf = conf_map.reshape(conf_map.shape[1], -1).max(axis=0)
        idx = int(np.argmax(cell_conf))
        if cell_conf[idx] > best_conf:
            best_conf, best_scale, best_idx = cell_conf[idx], head, idx
    return best_conf, best_scale, best_idx


def _decode_mask(coeff_map, cell_idx, proto):
    """Combine one cell's mask coefficients with the prototypes.

    *coeff_map* is reshaped to (num_coeffs, H*W) and *proto* to
    (num_coeffs, Hp*Wp). The result is a probability mask of shape (Hp, Wp).
    """
    coeff = coeff_map.reshape(coeff_map.shape[1], -1)
    logits = coeff[:, cell_idx] @ proto.reshape(proto.shape[0], -1)
    return sigmoid(logits).reshape(proto.shape[1], proto.shape[2])


def _unletterbox_mask(mask, input_size, orig_shape, scale, offset_x, offset_y):
    """Map a proto-grid mask back onto the original image.

    The proto grid covers the whole padded letterbox canvas. So the mask is
    first brought to the canvas size, then the padding is cropped away, and
    only then is it resized to the original image size.
    """
    in_w, in_h = input_size
    orig_h, orig_w = orig_shape[:2]
    # Mirror the int() truncation letterbox_resize uses for the pasted image.
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))
    mask_canvas = cv2.resize(mask, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
    mask_crop = mask_canvas[offset_y:offset_y + new_h, offset_x:offset_x + new_w]
    return cv2.resize(mask_crop, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)


def detect_single_mask(model_path, image_path, output_path="mask_result.png",
                       conf_thresh=None, input_size=(640, 640)):
    """Segment the single most confident object in an image with an RKNN model.

    Args:
        model_path: path to the ``.rknn`` model file.
        image_path: image to read with ``cv2.imread``.
        output_path: where to write the binary mask; ``None`` skips saving.
        conf_thresh: minimum confidence. Defaults to the module-level
            ``objectThresh``, read at call time.
        input_size: ``(width, height)`` the image is letterboxed to. Must match
            the model's input size.

    Returns:
        uint8 mask with the original image's height and width, containing 255
        on the object and 0 elsewhere. Returns None if the image cannot be
        read, the model fails to run, or no cell reaches the threshold.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 错误:无法读取图像 {image_path}")
        return None

    img_resized, scale, offset_x, offset_y = letterbox_resize(img, input_size)
    # cvtColor yields a contiguous RGB array (a [..., ::-1] view would have a
    # negative stride); add the batch axis expected by inference.
    infer_img = np.expand_dims(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB), 0)

    outputs = _run_rknn(model_path, infer_img)
    if outputs is None:
        return None

    # Assumed output order (13 tensors): per head, indices 1/5/9 are the
    # confidence maps and 3/7/11 the mask coefficients; index 12 holds the
    # mask prototypes, (32, 160, 160) for a 640x640 input. NOTE(review):
    # confirm against the exported model. Scores are compared directly to the
    # threshold, so they are assumed to already be probabilities.
    conf_list = [outputs[1], outputs[5], outputs[9]]
    mask_coeffs_list = [outputs[3], outputs[7], outputs[11]]
    proto = outputs[12][0]

    thresh = objectThresh if conf_thresh is None else conf_thresh
    best_conf, best_scale, best_idx = _select_best_cell(conf_list)
    if best_conf < thresh:
        print(f"⚠️ 置信度低于阈值 {thresh},未检测到目标")
        return None

    mask = _decode_mask(mask_coeffs_list[best_scale], best_idx, proto)
    mask_resized = _unletterbox_mask(mask, input_size, img.shape, scale, offset_x, offset_y)
    mask_bin = (mask_resized > 0.5).astype(np.uint8) * 255

    if output_path is not None:
        if cv2.imwrite(output_path, mask_bin):
            print(f"✅ 单目标 mask 已保存: {output_path}")
        else:
            print(f"❌ 错误:无法保存 mask 到 {output_path}")

    return mask_bin
2025-09-01 14:14:18 +08:00
2025-09-05 14:29:33 +08:00
# ------------------- Example invocation -------------------
if __name__ == "__main__":
    # Demo inputs, resolved relative to the current working directory.
    model_path, image_path = "seg.rknn", "2.jpg"
    detect_single_mask(model_path, image_path)