203 lines
5.4 KiB
Python
203 lines
5.4 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
from rknnlite.api import RKNNLite
|
|
|
|
# ====================== Configuration ======================

# Compiled RKNN model file; loaded once at module import.
MODEL_PATH = "bag3568.rknn"

# Network input size as (width, height); every frame is letterboxed to it.
IMG_SIZE = (640, 640)

# Minimum per-cell score (best class score x third-branch score) kept before NMS.
OBJ_THRESH = 0.25
# NMS keeps a box only if its IoU with an already-kept box is <= this value.
NMS_THRESH = 0.45

# Class index -> label, in the order of the model's class-score channels.
CLASS_NAME = ["bag", "bag35"]
|
|
|
|
# ====================== Utility functions ======================


def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    logits cannot overflow ``np.exp``; the shift cancels out in the ratio.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numer = np.exp(shifted)
    denom = numer.sum(axis=axis, keepdims=True)
    return numer / denom
|
|
|
|
def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping its aspect ratio, then pad.

    The image is scaled by the largest factor that fits both target
    dimensions and centred on a canvas filled with ``bg_color``.

    Args:
        image (np.ndarray): Source image, ``(H, W)`` or ``(H, W, C)``.
        size (tuple[int, int]): Target ``(width, height)``.
        bg_color (int | sequence): Fill value for the padded border.

    Returns:
        tuple: ``(canvas, scale, dx, dy)``. ``canvas`` has shape
        ``(height, width[, C])`` and the input's dtype, ``scale`` is the resize
        factor, and ``dx``/``dy`` are the left/top padding in pixels. A canvas
        point ``(x, y)`` maps back to the source approximately as
        ``((x - dx) / scale, (y - dy) / scale)``.
    """
    target_w, target_h = size
    h, w = image.shape[:2]
    scale = min(target_w / w, target_h / h)

    # Round instead of truncating so float error (e.g. 639.9999) does not
    # drop a pixel, and clamp to [1, target] so extreme aspect ratios never
    # ask cv2.resize for an empty output.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))
    resized = cv2.resize(image, (new_w, new_h))

    # Derive channels/dtype from the resized image rather than hard-coding
    # 3-channel uint8 (cv2.resize drops a trailing singleton channel).
    canvas = np.full(
        (target_h, target_w) + resized.shape[2:], bg_color, dtype=resized.dtype
    )
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized

    return canvas, scale, dx, dy
|
|
|
|
# ====================== DFL decoding ======================


def dfl_decode(reg):
    """Decode Distribution Focal Loss logits for one cell into four distances.

    ``reg`` is reshaped into four rows (one per box side). Each row is
    turned into a distribution over bins ``0..n-1`` with ``softmax``, and the
    expected bin index of each row is returned.

    Args:
        reg (np.ndarray): ``4 * n`` regression logits for a single grid cell.

    Returns:
        np.ndarray: Shape ``(4,)`` expected distances, in bin units (callers
        scale them by the branch stride).
    """
    per_side = reg.reshape(4, -1)
    weights = softmax(per_side, axis=1)
    bins = np.arange(per_side.shape[1])
    return (weights * bins).sum(axis=1)
|
|
|
|
# ====================== NMS ======================


def nms(boxes, scores, thresh):
    """Greedy, class-agnostic non-maximum suppression.

    Boxes are visited in descending score order; each kept box suppresses
    every remaining box whose IoU with it exceeds ``thresh``.

    Args:
        boxes: Array-like of ``[x1, y1, x2, y2]`` boxes, shape ``(N, 4)``.
        scores: Array-like of ``N`` scores.
        thresh (float): IoU threshold; a box survives if its IoU ``<= thresh``.

    Returns:
        list[int]: Indices of kept boxes, highest score first; ``[]`` for
        empty input.
    """
    # reshape(-1, 4) makes an empty input unpack into four empty columns
    # instead of failing on ``boxes.T``.
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float64).reshape(-1)

    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))

        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        union = areas[best] + areas[rest] - inter

        # Zero-area pairs have union 0; report IoU 0 instead of NaN (0/0).
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        order = rest[iou <= thresh]

    return keep
|
|
|
|
# ====================== Post-processing ======================


def _decode_branch(reg, cls, obj, stride):
    """Score, threshold and DFL-decode one output branch, vectorised over its grid.

    Args:
        reg (np.ndarray): Box logits, ``(4 * reg_max, H, W)``.
        cls (np.ndarray): Class scores, ``(num_classes, H, W)``.
        obj (np.ndarray): Score map whose first channel is used, ``(>=1, H, W)``.
        stride (int): Input pixels per grid cell for this branch.

    Returns:
        tuple: ``(boxes, scores, classes)`` for cells whose score is at least
        ``OBJ_THRESH``, in row-major ``(h, w)`` order. ``boxes`` is ``(K, 4)``
        float64 ``x1, y1, x2, y2`` in letterboxed-input pixels.
    """
    num_classes, grid_h, grid_w = cls.shape
    cells = grid_h * grid_w

    # Best class for every cell, flattened row-major so index = h * W + w.
    cls_flat = cls.reshape(num_classes, cells)
    cls_ids = cls_flat.argmax(axis=0)
    cls_scores = cls_flat[cls_ids, np.arange(cells)]

    # NOTE(review): this treats the third tensor as objectness. DFL-head models
    # (e.g. YOLOv8) have no objectness; Rockchip rknn_model_zoo exports emit a
    # clamped sum of class scores here, which its Python post-process ignores.
    # If this model is such an export, multiplying lowers scores — confirm.
    scores = cls_scores * obj[0].reshape(cells)

    hits = np.flatnonzero(scores >= OBJ_THRESH)
    if hits.size == 0:
        return np.empty((0, 4)), scores[hits], cls_ids[hits]

    # DFL: split channels into 4 sides x reg_max bins and take the expected
    # bin index of each side's softmax distribution.
    logits = reg.reshape(4, -1, cells)[:, :, hits]
    prob = softmax(logits, axis=1)
    bins = np.arange(logits.shape[1]).reshape(1, -1, 1)
    left, top, right, bottom = (prob * bins).sum(axis=1)

    # Anchor points are the centres of the kept grid cells.
    rows, cols = np.divmod(hits, grid_w)
    cx = (cols + 0.5) * stride
    cy = (rows + 0.5) * stride

    boxes = np.stack(
        [cx - left * stride, cy - top * stride, cx + right * stride, cy + bottom * stride],
        axis=1,
    ).astype(np.float64)
    return boxes, scores[hits], cls_ids[hits]


def post_process(outputs, scale, dx, dy):
    """Turn raw RKNN outputs into detections in original-image coordinates.

    ``outputs`` is read as three branches (strides 8, 16, 32) of three
    consecutive tensors each: DFL box logits, class scores and a score map.
    Cells are scored and thresholded per branch, all survivors go through
    class-agnostic NMS, and the kept boxes are mapped back through the
    letterbox transform.

    Args:
        outputs (list[np.ndarray]): Nine tensors from ``RKNNLite.inference``,
            each with a leading batch dimension of which index 0 is used.
            Presumably ``(1, 4*reg_max, H, W)``, ``(1, C, H, W)`` and
            ``(1, 1, H, W)`` per branch — confirm against the model.
        scale (float): Resize factor from ``letterbox_resize``.
        dx (int): Left padding from ``letterbox_resize``.
        dy (int): Top padding from ``letterbox_resize``.

    Returns:
        tuple: ``(boxes, classes, scores)`` with shapes ``(K, 4)``, ``(K,)``,
        ``(K,)`` in NMS keep order (descending score), or
        ``(None, None, None)`` when no cell reaches ``OBJ_THRESH``.
    """
    strides = (8, 16, 32)
    parts = [
        _decode_branch(outputs[i * 3][0], outputs[i * 3 + 1][0], outputs[i * 3 + 2][0], stride)
        for i, stride in enumerate(strides)
    ]
    box_parts, score_parts, class_parts = zip(*parts)

    boxes_all = np.concatenate(box_parts)
    scores_all = np.concatenate(score_parts)
    classes_all = np.concatenate(class_parts)

    if len(boxes_all) == 0:
        return None, None, None

    keep = nms(boxes_all, scores_all, NMS_THRESH)

    boxes = boxes_all[keep]
    scores = scores_all[keep]
    classes = classes_all[keep]

    # Undo the letterbox: remove the padding offset, then the resize.
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - dx) / scale
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - dy) / scale

    return boxes, classes, scores
|
|
|
|
# ====================== RKNN initialisation (once, at import) ======================


def _init_rknn(model_path):
    """Create an RKNNLite runtime for ``model_path`` or raise if it cannot start.

    ``load_rknn`` and ``init_runtime`` signal failure through a non-zero
    return code rather than an exception. Checking the codes here fails fast
    at import instead of crashing later when ``inference`` output is used.

    Raises:
        RuntimeError: If loading the model or initialising the runtime fails.
    """
    rknn = RKNNLite()

    ret = rknn.load_rknn(model_path)
    if ret != 0:
        raise RuntimeError(f"load_rknn({model_path!r}) failed with code {ret}")

    ret = rknn.init_runtime()
    if ret != 0:
        raise RuntimeError(f"init_runtime() failed with code {ret} for {model_path!r}")

    return rknn


_rknn = _init_rknn(MODEL_PATH)
|
|
|
|
# ====================== Public interface ======================


def detect_bag(img, return_vis=False):
    """Detect bags in a BGR image and report the single best detection.

    Args:
        img (np.ndarray): Original BGR image (e.g. as returned by ``cv2.imread``).
        return_vis (bool): If True, also return a copy of ``img`` with the
            best box and its label drawn on it.

    Returns:
        ``(cls, conf, min_x)``, or ``(cls, conf, min_x, vis_img)`` when
        ``return_vis`` is True:

        - cls (str | None): Class name of the highest-scoring detection.
        - conf (float | None): Its score.
        - min_x (int | None): Left edge of its box in ``img`` pixels, clipped
          to ``[0, width - 1]``.
        - vis_img (np.ndarray): Annotated copy of ``img``; an unannotated copy
          when nothing is detected.

        ``cls``, ``conf`` and ``min_x`` are all None when nothing passes the
        thresholds.
    """
    # NOTE(review): the letterboxed image is fed to the model in BGR order.
    # Rockchip's demos convert BGR->RGB before inference; confirm which channel
    # order this .rknn model was converted for.
    img_r, scale, dx, dy = letterbox_resize(img, IMG_SIZE)
    outputs = _rknn.inference([np.expand_dims(img_r, 0)])

    boxes, cls_ids, scores = post_process(outputs, scale, dx, dy)

    if boxes is None or len(scores) == 0:
        if return_vis:
            return None, None, None, img.copy()
        return None, None, None

    best_idx = int(np.argmax(scores))
    conf = float(scores[best_idx])
    cls = CLASS_NAME[int(cls_ids[best_idx])]

    # Boxes mapped back through the letterbox can extend into the padding,
    # i.e. outside the image; clip to valid pixel indices. tolist() yields
    # plain Python ints, which cv2 drawing functions accept.
    img_h, img_w = img.shape[:2]
    upper = np.array([img_w - 1, img_h - 1, img_w - 1, img_h - 1])
    x1, y1, x2, y2 = np.clip(boxes[best_idx], 0, upper).astype(int).tolist()
    min_x = x1

    if return_vis:
        vis = img.copy()
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            vis,
            f"{cls}:{conf:.3f}",
            (x1, max(y1 - 5, 0)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 255, 0),
            2,
        )
        return cls, conf, min_x, vis

    return cls, conf, min_x
|
|
|
|
|
|
# ====================== Manual smoke test ======================
# Runs one image through detect_bag, prints the result and saves the drawing.

if __name__ == "__main__":
    IMG_PATH = "./test_image/4.jpg"
    OUTPUT_DIR = "./result"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    frame = cv2.imread(IMG_PATH)
    if frame is None:
        raise FileNotFoundError(IMG_PATH)

    label, score, left_x, annotated = detect_bag(frame, return_vis=True)

    if label is not None:
        print(f"类别: {label}")
        print(f"置信度: {score:.4f}")
        print(f"最左 x: {left_x}")
    else:
        print("未检测到目标")

    if annotated is not None:
        out_file = os.path.join(OUTPUT_DIR, "vis_result.jpg")
        cv2.imwrite(out_file, annotated)
        print("可视化结果已保存:", out_file)
|
|
|
|