Files
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

143 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from rknnlite.api import RKNNLite
# ======================
# Configuration
# ======================
# Input image, compiled RKNN model and the final overlay image.
IMAGE_PATH = "3.png"
MODEL_PATH = "segr.rknn"
OUT_OVERLAY = "result_overlay.jpg"

# Debug images that infer_single_roi overwrites on every ROI.
DEBUG_INPUT = "debug_input_roi.png"
DEBUG_PROTO = "debug_proto_mask.png"
DEBUG_INST_PROTO = "debug_inst_proto.png"

# Side length of the square model input (ROIs are resized to IMG_SIZE x IMG_SIZE).
IMG_SIZE = 640
# Minimum anchor score required to accept a detection.
OBJ_THRESH = 0.25
# Probability threshold used to binarise the instance mask.
MASK_THRESH = 0.5
# Strides of the per-level head outputs, in the order they appear in the outputs.
STRIDES = [8, 16, 32]

# Regions of interest as (x, y, w, h) in source-image pixels.
ROIS = [(670, 623, 465, 178)]
# ======================
# 工具函数
# ======================
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp``. This keeps
    ``np.exp`` from overflowing on strongly negative inputs (below about -88.7
    for float32 logits). With the naive form, such inputs emit a RuntimeWarning,
    or raise under a strict ``np.errstate``.

    Args:
        x: scalar or array of logits.

    Returns:
        Values in [0, 1] with the same shape as *x* (floating dtype).
    """
    return np.exp(-np.logaddexp(0, -x))
def resize_to_640(img):
    """Stretch *img* to an ``IMG_SIZE`` x ``IMG_SIZE`` segmentation input.

    Letterboxing is deliberately not used. The image is resized directly with
    bilinear interpolation, so the aspect ratio is not preserved. This matches
    the plain resize used to map the predicted mask back onto the ROI.
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    resized = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
    return resized
def dfl_decode(dfl, axis=1):
    """Decode Distribution Focal Loss (DFL) logits into expected bin offsets.

    Each value is represented as a discrete distribution over ``n`` bins, where
    ``n = dfl.shape[axis]`` (16 in the original hard-coded version). The
    distribution is a softmax over the bin axis. The decoded value is its
    expectation ``sum_i p_i * i``.

    Args:
        dfl: array of raw DFL logits whose bin dimension lies on *axis*.
        axis: index of the bin dimension. Defaults to 1, as in the original.

    Returns:
        Array of the same shape with *axis* reduced away, in units of bins.
    """
    logits = np.asarray(dfl)
    # Subtract the per-distribution maximum so np.exp cannot overflow.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    weights /= weights.sum(axis=axis, keepdims=True)

    n_bins = logits.shape[axis]
    # Shape the bin indices so they broadcast along `axis` for any input rank
    # (a flat arange would broadcast against the last axis instead).
    bin_shape = [1] * logits.ndim
    bin_shape[axis] = n_bins
    bins = np.arange(n_bins, dtype=weights.dtype).reshape(bin_shape)
    return (weights * bins).sum(axis=axis)
def largest_cc(mask):
    """Keep only the largest foreground connected component of a binary mask.

    Args:
        mask: 2-D array. It is converted with ``astype(np.uint8)``, and
            non-zero results are treated as foreground.

    Returns:
        ``uint8`` array of the same shape, 1 on the largest component and 0
        elsewhere. With no foreground the result is all zeros. Ties go to the
        component with the lowest label.
    """
    binary = mask.astype(np.uint8)
    num, labels, stats, _centroids = cv2.connectedComponentsWithStats(binary)
    if num <= 1:
        # Only label 0 (background) exists: nothing to keep.
        return np.zeros_like(binary)
    # Row 0 of `stats` is the background. CC_STAT_AREA is the per-label pixel
    # count, which avoids one full-image scan per label.
    areas = stats[1:, cv2.CC_STAT_AREA]
    biggest = int(np.argmax(areas)) + 1
    return (labels == biggest).astype(np.uint8)
# ======================
# 单 ROI 推理(完整语义 mask
# ======================
def _select_best_coef(outputs):
    """Return the mask coefficients of the best-scoring anchor, or None.

    Head outputs are read as four consecutive tensors per entry of STRIDES:
    (reg, cls, obj, coef). For each level, the anchor maximising
    ``sigmoid(cls) * sigmoid(obj)`` is found. Across levels, the highest score
    that reaches OBJ_THRESH wins.
    """
    best_score = -1
    best_coef = None
    for base in range(0, 4 * len(STRIDES), 4):
        # outputs[base] is the reg branch; it is not needed to build the mask.
        # [0, 0] picks batch 0 / channel 0 — assumes a single-class model; confirm.
        cls = outputs[base + 1][0, 0]
        obj = outputs[base + 2][0, 0]
        coef = outputs[base + 3][0]
        score_map = sigmoid(cls) * sigmoid(obj)
        y, x = np.unravel_index(np.argmax(score_map), score_map.shape)
        score = score_map[y, x]
        if score < OBJ_THRESH or score <= best_score:
            continue
        best_score = score
        best_coef = coef[:, y, x]
    return best_coef


def infer_single_roi(rknn, roi):
    """Segment the single best instance inside *roi*.

    Pipeline:
      1. Resize the ROI directly to the model input (no letterbox).
      2. Run the model.
      3. Pick the best anchor across all strides.
      4. Build its mask from the prototype tensor and threshold it.
      5. Keep the largest connected component and resize it back to the ROI
         size.

    DEBUG_INPUT, DEBUG_PROTO and DEBUG_INST_PROTO are overwritten on every
    call.

    Args:
        rknn: an initialised RKNNLite runtime.
        roi: image crop of shape (H, W, 3). In main it is a slice of a
            cv2.imread image, i.e. BGR.

    Returns:
        ``uint8`` mask of shape (H, W), 255 on the instance and 0 elsewhere.
        Returns None if the ROI is empty or no anchor reaches OBJ_THRESH.

    Raises:
        RuntimeError: if the runtime returns no outputs.
    """
    h0, w0 = roi.shape[:2]
    if h0 == 0 or w0 == 0:
        # cv2.resize rejects empty images; treat an empty crop as "no detection".
        return None

    # 1. Model input: plain resize, BGR -> RGB, add a batch dimension.
    inp_img = resize_to_640(roi)
    cv2.imwrite(DEBUG_INPUT, inp_img)
    # [..., ::-1] is a negative-stride view; pass a contiguous copy instead.
    inp = np.ascontiguousarray(inp_img[..., ::-1])[None, ...]
    outputs = rknn.inference([inp])
    if outputs is None:
        raise RuntimeError("RKNN inference returned no outputs")

    # 2. Prototype tensor, expected (32, 160, 160). It follows the 12
    #    per-stride head outputs.
    proto = outputs[12][0]
    best_coef = _select_best_coef(outputs)
    if best_coef is None:
        return None

    # 3. Full-frame instance probability map: sigmoid(coef . proto) -> (160, 160).
    proto_mask = sigmoid(np.tensordot(best_coef, proto, axes=1))
    # Min-max normalisation is only for the debug visualisation.
    lo = proto_mask.min()
    pm = (proto_mask - lo) / (proto_mask.max() - lo + 1e-6)
    cv2.imwrite(DEBUG_PROTO, (pm * 255).astype(np.uint8))

    # 4. Binarise and keep the largest connected component (no box cropping).
    inst_proto = (proto_mask > MASK_THRESH).astype(np.uint8)
    inst_proto = largest_cc(inst_proto)
    cv2.imwrite(DEBUG_INST_PROTO, inst_proto * 255)

    # 5. Proto resolution -> ROI resolution. Nearest-neighbour keeps it binary.
    inst_roi = cv2.resize(inst_proto, (w0, h0), interpolation=cv2.INTER_NEAREST)
    return inst_roi * 255
# ======================
# 主程序
# ======================
def main():
    """Segment every ROI of IMAGE_PATH and save a green overlay to OUT_OVERLAY.

    Raises:
        FileNotFoundError: if IMAGE_PATH cannot be read.
        RuntimeError: if the RKNN model cannot be loaded or the runtime cannot
            be initialised.
    """
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"cannot read image: {IMAGE_PATH}")
    overlay = img.copy()

    rknn = RKNNLite()
    try:
        # RKNNLite reports failures through a non-zero return code.
        if rknn.load_rknn(MODEL_PATH) != 0:
            raise RuntimeError(f"failed to load RKNN model: {MODEL_PATH}")
        if rknn.init_runtime() != 0:
            raise RuntimeError("failed to initialise the RKNN runtime")

        for (x, y, w, h) in ROIS:
            roi = img[y:y + h, x:x + w]
            if roi.size == 0:
                # ROI lies completely outside the image: nothing to segment.
                continue
            mask = infer_single_roi(rknn, roi)
            if mask is None:
                continue
            color = np.zeros_like(roi)
            color[mask == 255] = (0, 255, 0)  # green in OpenCV's BGR order
            # Blend the whole ROI with the colour layer (mask pixels turn green).
            overlay[y:y + h, x:x + w] = cv2.addWeighted(
                roi, 0.7, color, 0.3, 0
            )
    finally:
        # Release the NPU runtime even if loading or inference failed.
        rknn.release()

    cv2.imwrite(OUT_OVERLAY, overlay)
    print("✅ 完成:", OUT_OVERLAY)


if __name__ == "__main__":
    main()