182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
|
|
import os
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from rknnlite.api import RKNNLite
|
|||
|
|
|
|||
|
|
# ====================== Configuration ======================
# RKNN model file that detect_bag() hands to init_rknn().
MODEL_PATH = "bag3588.rknn"
# Image read by the __main__ test block at the bottom of the file.
IMG_PATH = "2.jpg"
# Model input size as (width, height); letterbox_resize() unpacks it in that order.
IMG_SIZE = (640, 640)
# Minimum combined score (confidence * best class probability) kept by filter_boxes().
OBJ_THRESH = 0.001
# Boxes whose IoU with an already kept box exceeds this value are suppressed in filter_boxes().
NMS_THRESH = 0.45
# Class labels, indexed by class id when drawing detections.
CLASS_NAME = ["bag"]
# Directory where detect_bag() writes its visualization image.
OUTPUT_DIR = "./result"
# Runs at import time so the output directory always exists before any write.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|||
|
|
|
|||
|
|
# ====================== Global RKNN ======================
# Process-wide RKNNLite instance, created lazily by init_rknn().
_global_rknn = None


def init_rknn(model_path):
    """Return the shared RKNNLite runtime, creating it on first use.

    The runtime is created once and cached in ``_global_rknn``; later calls
    return the cached instance and ignore ``model_path``.

    Args:
        model_path: Path of the ``.rknn`` model to load on the first call.

    Returns:
        The initialised ``RKNNLite`` instance.

    Raises:
        RuntimeError: If loading the model or initialising the runtime
            reports a non-zero status code.
    """
    global _global_rknn
    if _global_rknn is None:
        rknn = RKNNLite(verbose=False)

        # Both calls report failure through their return value rather than
        # an exception; check it so a broken runtime is never cached.
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(
                f"Failed to load RKNN model '{model_path}' (ret={ret})"
            )

        ret = rknn.init_runtime()
        if ret != 0:
            raise RuntimeError(
                f"Failed to initialise RKNN runtime for '{model_path}' (ret={ret})"
            )

        _global_rknn = rknn
    return _global_rknn
|
|||
|
|
|
|||
|
|
# ====================== Utility functions ======================
def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping its aspect ratio, then pad.

    The resized image is centred on a 3-channel ``uint8`` canvas filled with
    ``bg_color``.

    Args:
        image: Array whose first two dimensions are (height, width). It is
            pasted into a 3-channel canvas, so presumably an HxWx3 uint8
            image -- confirm against callers.
        size: Target size as ``(width, height)``.
        bg_color: Fill value for the padded border.

    Returns:
        Tuple ``(canvas, scale, dx, dy)``: the padded image, the resize
        factor that was applied, and the left/top offsets of the resized
        image inside the canvas.
    """
    target_w, target_h = size
    src_h, src_w = image.shape[:2]
    scale = min(target_w / src_w, target_h / src_h)

    # int() truncates, so a very elongated image could end up with a zero
    # dimension, which cv2.resize rejects; keep at least one pixel.
    new_w = max(1, int(src_w * scale))
    new_h = max(1, int(src_h * scale))
    resized = cv2.resize(image, (new_w, new_h))

    canvas = np.full((target_h, target_w, 3), bg_color, dtype=np.uint8)
    dx = (target_w - new_w) // 2
    dy = (target_h - new_h) // 2
    canvas[dy:dy + new_h, dx:dx + new_w] = resized
    return canvas, scale, dx, dy
|
|||
|
|
|
|||
|
|
def dfl_numpy(position):
    """Decode DFL-style box outputs into one expected value per box side.

    The channel axis is split into 4 groups of ``c // 4`` bins. A softmax is
    taken over each group and the result is the expectation of the bin index
    under that distribution.

    Args:
        position: Array of shape ``(n, c, h, w)`` where ``c`` is a multiple
            of 4.

    Returns:
        Array of shape ``(n, 4, h, w)`` holding the expected bin index for
        each of the 4 groups.
    """
    n, c, h, w = position.shape
    p_num = 4
    mc = c // p_num
    logits = position.reshape(n, p_num, mc, h, w)

    # Subtract the per-distribution maximum before exponentiating: the
    # softmax is unchanged mathematically, but exp() can no longer overflow
    # to inf (which would turn the result into NaN).
    shifted = logits - np.max(logits, axis=2, keepdims=True)
    weights = np.exp(shifted)
    probs = weights / np.sum(weights, axis=2, keepdims=True)

    bins = np.arange(mc).reshape(1, 1, mc, 1, 1)
    return np.sum(probs * bins, axis=2)
|
|||
|
|
|
|||
|
|
def box_process(position):
    """Convert one branch's raw box tensor into ``xyxy`` boxes in input pixels.

    Args:
        position: Array of shape ``(n, c, grid_h, grid_w)`` that
            ``dfl_numpy`` decodes into four distances per grid cell.

    Returns:
        Array of shape ``(n, 4, grid_h, grid_w)`` with channels
        ``(x1, y1, x2, y2)`` scaled to the ``IMG_SIZE`` coordinate frame.
    """
    grid_h, grid_w = position.shape[2:4]
    col, row = np.meshgrid(np.arange(grid_w), np.arange(grid_h))
    col = col.reshape(1, 1, grid_h, grid_w)
    row = row.reshape(1, 1, grid_h, grid_w)
    # Channel 0 is the x (column) index, channel 1 the y (row) index.
    grid = np.concatenate((col, row), axis=1)

    # IMG_SIZE is (width, height). The stride vector must follow the same
    # (x, y) channel order as the grid: width stride first, height second.
    stride = np.array(
        [IMG_SIZE[0] // grid_w, IMG_SIZE[1] // grid_h]
    ).reshape(1, 2, 1, 1)

    position = dfl_numpy(position)
    # Decoded channels 0:2 are subtracted from the cell centre (top-left
    # corner), channels 2:4 are added to it (bottom-right corner).
    box_xy = grid + 0.5 - position[:, 0:2, :, :]
    box_xy2 = grid + 0.5 + position[:, 2:4, :, :]
    return np.concatenate((box_xy * stride, box_xy2 * stride), axis=1)
|
|||
|
|
|
|||
|
|
def filter_boxes(boxes, box_confidences, box_class_probs):
    """Drop low-scoring boxes, then run greedy non-maximum suppression.

    Each box's score is its confidence multiplied by its highest class
    probability. Boxes scoring below ``OBJ_THRESH`` are discarded; the
    survivors go through NMS with ``NMS_THRESH``. The NMS does not look at
    the class id, so boxes of different classes can suppress each other.

    Args:
        boxes: Box coordinates, reshaped to ``(N, 4)`` as ``x1, y1, x2, y2``.
        box_confidences: Per-box confidence, flattened to ``(N,)``.
        box_class_probs: Per-box class probabilities, indexed as ``(N, C)``.

    Returns:
        Tuple ``(boxes, classes, scores, conf_keep)`` for the kept boxes in
        descending score order, or ``(None, None, None, None)`` when no box
        reaches ``OBJ_THRESH``.
    """
    boxes = np.array(boxes).reshape(-1, 4)
    box_confidences = np.array(box_confidences).reshape(-1)
    box_class_probs = np.array(box_class_probs)

    # Best class per box and the combined score used for ranking.
    best_class = np.argmax(box_class_probs, axis=-1)
    best_prob = box_class_probs[np.arange(len(best_class)), best_class]
    combined = box_confidences * best_prob

    selected = combined >= OBJ_THRESH
    if not selected.any():
        return None, None, None, None

    boxes = boxes[selected]
    classes = best_class[selected]
    scores = combined[selected]
    conf_keep = box_confidences[selected]

    left, top, right, bottom = boxes.T
    areas = (right - left + 1) * (bottom - top + 1)

    # Greedy NMS: repeatedly take the best remaining box and drop every
    # other box that overlaps it by more than NMS_THRESH.
    remaining = scores.argsort()[::-1]
    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        inter_w = np.maximum(
            0,
            np.minimum(right[best], right[rest])
            - np.maximum(left[best], left[rest])
            + 1,
        )
        inter_h = np.maximum(
            0,
            np.minimum(bottom[best], bottom[rest])
            - np.maximum(top[best], top[rest])
            + 1,
        )
        inter = inter_w * inter_h
        iou = inter / (areas[best] + areas[rest] - inter)

        remaining = rest[iou <= NMS_THRESH]

    return boxes[keep], classes[keep], scores[keep], conf_keep[keep]
|
|||
|
|
|
|||
|
|
def post_process(outputs, scale, dx, dy):
    """Turn raw model outputs into boxes in original-image coordinates.

    The outputs are read as three branches of three tensors each: for branch
    ``k`` the box tensor is ``outputs[3k]``, the class probabilities are
    ``outputs[3k + 1]`` and the confidence is ``outputs[3k + 2]``.

    Args:
        outputs: Sequence of at least nine ``(n, c, h, w)`` arrays.
        scale: Resize factor applied by the letterbox step.
        dx: Horizontal padding offset added by the letterbox step.
        dy: Vertical padding offset added by the letterbox step.

    Returns:
        Tuple ``(boxes, classes, scores, conf_keep)``, or four ``None``
        values when ``filter_boxes`` keeps nothing. ``scores`` is returned as
        ``1 - score`` and ``conf_keep`` is multiplied by 255.
    """

    def to_rows(tensor):
        # (n, c, h, w) -> (n * h * w, c): one row per grid cell.
        return tensor.transpose(0, 2, 3, 1).reshape(-1, tensor.shape[1])

    branch_offsets = range(0, 9, 3)
    boxes = np.concatenate(
        [to_rows(box_process(outputs[k])) for k in branch_offsets]
    )
    box_conf = np.concatenate([to_rows(outputs[k + 2]) for k in branch_offsets])
    class_probs = np.concatenate(
        [to_rows(outputs[k + 1]) for k in branch_offsets]
    )

    boxes, classes, scores, conf_keep = filter_boxes(boxes, box_conf, class_probs)
    if boxes is None:
        return None, None, None, None

    # Undo the letterbox: remove the padding offset, then the resize factor.
    boxes[:, 0::2] -= dx
    boxes[:, 1::2] -= dy
    boxes /= scale
    boxes = np.clip(boxes, 0, None)

    # NOTE(review): the score is inverted and the confidence rescaled to
    # 0-255 here; presumably this matches how this model's outputs are
    # encoded -- confirm before relying on these values.
    return boxes, classes, 1 - scores, conf_keep * 255
|
|||
|
|
|
|||
|
|
# ====================== detect_bag ======================
def detect_bag(img, return_conf=True, return_vis=False):
    """Run the bag detector on one image.

    Args:
        img: Image array as accepted by ``letterbox_resize``.
            NOTE(review): it is passed to the model in the channel order the
            caller supplies (``cv2.imread`` yields BGR) -- confirm the model
            expects that.
        return_conf: Selects the return layout, see below.
        return_vis: When true, draw every detection on a copy of ``img`` and
            save it as ``vis_result.jpg`` in ``OUTPUT_DIR``.

    Returns:
        Always a 2-tuple:

        * ``return_conf=True``: ``(conf_val, min_x)``, where ``conf_val`` is
          the maximum of the scores returned by ``post_process`` and
          ``min_x`` is the smallest left edge among the boxes.
        * ``return_conf=False``: ``(min_x, vis_img)``; ``vis_img`` is ``None``
          unless ``return_vis`` is true.
        * No detection: ``(None, None)`` in both layouts.

    Raises:
        RuntimeError: If the runtime returns no outputs for the inference.
    """
    rknn = init_rknn(MODEL_PATH)

    img_resized, scale, dx, dy = letterbox_resize(img, IMG_SIZE)
    input_data = np.expand_dims(img_resized, 0)
    outputs = rknn.inference(inputs=[input_data])
    if outputs is None:
        # Fail with a clear message instead of a TypeError deep inside
        # post_process.
        raise RuntimeError("RKNN inference returned no outputs")

    boxes, classes, scores, _ = post_process(outputs, scale, dx, dy)

    if boxes is None or len(boxes) == 0:
        # Same arity as the success paths so callers can always unpack two
        # values, whichever layout they asked for.
        return None, None

    min_x = float(boxes[:, 0].min())
    conf_val = float(scores.max()) if return_conf else None
    vis_img = None

    if return_vis:
        vis_img = img.copy()
        for box, cls_id, score in zip(boxes, classes, scores):
            # Plain Python ints: OpenCV drawing functions can reject other
            # numeric types for point coordinates.
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(
                vis_img,
                f"{CLASS_NAME[cls_id]}:{score:.1f}",
                (x1, max(y1 - 5, 0)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (0, 255, 0),
                2,
            )
        save_path = os.path.join(OUTPUT_DIR, "vis_result.jpg")
        cv2.imwrite(save_path, vis_img)

    if return_conf:
        return conf_val, min_x
    return min_x, vis_img
|
|||
|
|
|
|||
|
|
# ====================== Test ======================
if __name__ == "__main__":
    # Manual smoke test: run the detector once on IMG_PATH and print the result.
    image = cv2.imread(IMG_PATH)
    if image is None:
        raise FileNotFoundError(f"图片无法读取: {IMG_PATH}")

    # The return layout is controlled by the return_conf / return_vis flags.
    conf, min_x = detect_bag(image, return_conf=True, return_vis=True)

    if conf is not None:
        print(f"✅ 最大置信度: {conf:.4f}, 最左 x: {min_x:.1f}")
    else:
        print("❌ 未检测到 bag")
|
|||
|
|
|