Files
ailai_image_point_diff/detect_image/detect_bag.py
2025-12-28 00:12:46 +08:00

182 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from rknnlite.api import RKNNLite
# ====================== Configuration ======================
# Compiled RKNN model file; loaded once by init_rknn() on the first detect_bag() call.
MODEL_PATH = "bag3588.rknn"
# Sample image read only by the __main__ smoke test at the bottom of this file.
IMG_PATH = "2.jpg"
# Network input size as (width, height) -- letterbox_resize() unpacks it in that order.
IMG_SIZE = (640, 640)
# Minimum (box confidence * class score) a candidate needs to reach NMS.
# NOTE(review): 0.001 is extremely low, so almost every grid cell survives to NMS; confirm intended.
OBJ_THRESH = 0.001
# IoU above which a lower-scoring box is suppressed by NMS in filter_boxes().
NMS_THRESH = 0.45
# Label text, indexed by class id when drawing detections.
CLASS_NAME = ["bag"]
# Directory that receives the visualisation image written by detect_bag(return_vis=True).
OUTPUT_DIR = "./result"
# Side effect at import time: the output directory is created as soon as this module is imported.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ====================== Global RKNN runtime ======================
# Lazily created RKNNLite instance shared by every call to init_rknn() in this process.
_global_rknn = None


def init_rknn(model_path):
    """Return the shared RKNNLite runtime, creating it on first use.

    The first successful call loads *model_path* and initialises the runtime;
    the instance is cached in ``_global_rknn`` and returned by every later call
    (later calls ignore *model_path*).

    Args:
        model_path: path of the ``.rknn`` model file to load.

    Returns:
        The initialised, cached ``RKNNLite`` instance.

    Raises:
        RuntimeError: if ``load_rknn`` or ``init_runtime`` reports failure
            (RKNNLite signals failure with a non-zero return code).
    """
    global _global_rknn
    if _global_rknn is None:
        rknn = RKNNLite(verbose=False)
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            rknn.release()
            raise RuntimeError(f"load_rknn({model_path!r}) failed with code {ret}")
        ret = rknn.init_runtime()
        if ret != 0:
            rknn.release()
            raise RuntimeError(f"init_runtime() failed with code {ret}")
        # Cache only a fully initialised runtime so a failed attempt can be retried
        # instead of handing a broken instance to every later caller.
        _global_rknn = rknn
    return _global_rknn
# ====================== Utilities ======================
def letterbox_resize(image, size, bg_color=114):
    """Fit *image* inside *size* without distortion and pad the remainder.

    Args:
        image: image array whose first two axes are (height, width); the padded
            canvas is 3-channel uint8, so a 3-channel image is expected.
        size: output canvas size as (width, height).
        bg_color: grey level used for the padding border.

    Returns:
        ``(canvas, scale, dx, dy)``: the padded image, the resize factor that was
        applied to *image*, and the left / top padding in pixels. These are what
        a caller needs to map boxes back onto the original image.
    """
    out_w, out_h = size
    src_h, src_w = image.shape[:2]
    ratio = min(out_w / src_w, out_h / src_h)
    fit_w = int(src_w * ratio)
    fit_h = int(src_h * ratio)
    # Centre the resized image; any odd leftover pixel ends up on the right/bottom.
    pad_left = (out_w - fit_w) // 2
    pad_top = (out_h - fit_h) // 2
    padded = np.full((out_h, out_w, 3), bg_color, dtype=np.uint8)
    padded[pad_top:pad_top + fit_h, pad_left:pad_left + fit_w] = cv2.resize(image, (fit_w, fit_h))
    return padded, ratio, pad_left, pad_top
def dfl_numpy(position):
    """Decode Distribution Focal Loss (DFL) box logits into expected distances.

    The channel axis of *position* (shape ``(n, c, h, w)``) is split into 4
    groups of ``c // 4`` bins, one group per box side. Each group is
    softmax-normalised over its bins and reduced to its expectation
    ``sum_k p_k * k``.

    Args:
        position: 4-D array ``(n, c, h, w)`` with ``c`` a multiple of 4.

    Returns:
        Array of shape ``(n, 4, h, w)`` holding the expected bin index per side.
    """
    n, c, h, w = position.shape
    sides = 4
    bins = c // sides
    y = position.reshape(n, sides, bins, h, w)
    # Subtract the per-group maximum before exp() so large logits cannot overflow
    # to inf (inf / inf = NaN). Softmax is shift-invariant, so the result is unchanged.
    y = np.exp(y - y.max(axis=2, keepdims=True))
    y = y / y.sum(axis=2, keepdims=True)
    acc = np.arange(bins).reshape(1, 1, bins, 1, 1)
    return (y * acc).sum(axis=2)
def box_process(position):
    """Decode one branch's DFL box logits into xyxy boxes in letterboxed-input pixels.

    Each grid cell is anchored at its centre ``(col + 0.5, row + 0.5)``.
    ``dfl_numpy`` yields distances (left, top, right, bottom) from that anchor,
    measured in cells; they are converted to pixels by multiplying by the stride.

    Args:
        position: 4-D array ``(n, 4 * bins, grid_h, grid_w)``.

    Returns:
        Array ``(n, 4, grid_h, grid_w)`` with channels ``x1, y1, x2, y2``.
    """
    gh, gw = position.shape[2:4]
    xs, ys = np.meshgrid(np.arange(gw), np.arange(gh))
    cells = np.stack((xs, ys))[np.newaxis]  # (1, 2, gh, gw): channel 0 = x, channel 1 = y
    # NOTE(review): the x entry is derived from IMG_SIZE[1] // grid_h and the y entry
    # from IMG_SIZE[0] // grid_w; this equals the per-axis stride only when both axes
    # are downsampled equally (true for the square IMG_SIZE used here).
    stride = np.array([IMG_SIZE[1] // gh, IMG_SIZE[0] // gw]).reshape(1, 2, 1, 1)
    dist = dfl_numpy(position)
    centres = cells + 0.5
    top_left = (centres - dist[:, :2]) * stride
    bottom_right = (centres + dist[:, 2:4]) * stride
    return np.concatenate((top_left, bottom_right), axis=1)
def _nms(boxes, scores, iou_thresh):
    """Greedy class-agnostic NMS; return kept indices, highest score first.

    Widths/heights use the inclusive-pixel ``+1`` convention. A box is dropped
    when its IoU with an already-kept box is greater than *iou_thresh*.
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)
        inter_w = np.maximum(0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1)
        inter_h = np.maximum(0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1)
        inter = inter_w * inter_h
        iou = inter / (areas[best] + areas[rest] - inter)
        order = rest[iou <= iou_thresh]
    return keep


def filter_boxes(boxes, box_confidences, box_class_probs, obj_thresh=None, nms_thresh=None):
    """Score candidates, drop weak ones and apply class-agnostic NMS.

    Each candidate's score is ``box_confidence * probability of its best class``.

    Args:
        boxes: array-like reshaped to ``(N, 4)`` xyxy boxes.
        box_confidences: array-like with N values (flattened).
        box_class_probs: ``(N, num_classes)`` class probabilities.
        obj_thresh: minimum score to keep a candidate; defaults to ``OBJ_THRESH``.
        nms_thresh: IoU threshold for NMS; defaults to ``NMS_THRESH``.

    Returns:
        ``(boxes, classes, scores, confidences)`` for the surviving boxes, ordered
        by descending score, or ``(None, None, None, None)`` when no candidate
        reaches *obj_thresh*.
    """
    if obj_thresh is None:
        obj_thresh = OBJ_THRESH
    if nms_thresh is None:
        nms_thresh = NMS_THRESH
    boxes = np.asarray(boxes).reshape(-1, 4)
    box_confidences = np.asarray(box_confidences).reshape(-1)
    box_class_probs = np.asarray(box_class_probs)
    class_ids = np.argmax(box_class_probs, axis=-1)
    class_scores = box_class_probs[np.arange(len(class_ids)), class_ids]
    scores = box_confidences * class_scores
    mask = scores >= obj_thresh
    if not mask.any():
        return None, None, None, None
    boxes = boxes[mask]
    classes = class_ids[mask]
    scores = scores[mask]
    conf_keep = box_confidences[mask]
    keep = _nms(boxes, scores, nms_thresh)
    return boxes[keep], classes[keep], scores[keep], conf_keep[keep]
def post_process(outputs, scale, dx, dy):
    """Turn raw model outputs into detections in original-image pixels.

    Args:
        outputs: indexable sequence of at least 9 arrays, read as 3 branches of
            ``[box, class, conf]``. Index ``3*i`` is the DFL box tensor,
            ``3*i + 1`` the class probabilities and ``3*i + 2`` the box
            confidence. Each is a 4-D ``(N, C, H, W)`` array.
            NOTE(review): layout inferred from the indexing here; confirm it
            against the exported model.
        scale, dx, dy: resize factor and left/top padding from ``letterbox_resize``,
            used to undo the letterbox transform.

    Returns:
        ``(boxes, classes, scores, conf_keep)``, or four ``None`` when nothing
        passes ``filter_boxes``.
        - ``boxes`` are xyxy, clipped at 0 only (not at the right/bottom edge).
        - ``scores`` are returned as ``1 - score``.
        - ``conf_keep`` is the confidence multiplied by 255.
    """
    def to_rows(t):
        # (N, C, H, W) -> (N*H*W, C): one row per grid cell, channels last.
        channels = t.shape[1]
        return t.transpose(0, 2, 3, 1).reshape(-1, channels)

    box_rows, cls_rows, conf_rows = [], [], []
    for base in range(0, 3 * 3, 3):  # three branches, three tensors each
        box_rows.append(to_rows(box_process(outputs[base])))
        cls_rows.append(to_rows(outputs[base + 1]))
        conf_rows.append(to_rows(outputs[base + 2]))

    boxes, classes, scores, conf_keep = filter_boxes(
        np.concatenate(box_rows),
        np.concatenate(conf_rows),
        np.concatenate(cls_rows),
    )
    if boxes is None:
        return None, None, None, None

    # Undo the letterbox: remove the padding offset, then the resize factor.
    boxes[:, 0::2] -= dx  # x1, x2
    boxes[:, 1::2] -= dy  # y1, y2
    boxes /= scale
    boxes = boxes.clip(min=0)

    # NOTE(review): the scores are inverted *after* NMS has ranked boxes by the
    # un-inverted value. detect_bag reports scores.max() as the "max confidence",
    # so that value reflects the weakest kept box. The * 255 scaling below is also
    # unexplained. Both look like calibration hacks -- confirm they are intended.
    scores = 1 - scores
    conf_keep = conf_keep * 255
    return boxes, classes, scores, conf_keep
# ====================== detect_bag ======================
def _draw_detections(img, boxes, classes, scores):
    """Return a copy of *img* with each box and its ``name:score`` label drawn in green."""
    canvas = img.copy()
    for (x1, y1, x2, y2), cls_id, score in zip(boxes.astype(int), classes, scores):
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(canvas,
                    f"{CLASS_NAME[cls_id]}:{score:.1f}",
                    (x1, max(y1 - 5, 0)),  # label just above the box, kept on-image
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 0),
                    2)
    return canvas


def detect_bag(img, return_conf=True, return_vis=False):
    """Run the bag detector on *img* and summarise the detections.

    Args:
        img: H x W x 3 image array (the __main__ smoke test passes cv2.imread output).
        return_conf: selects the return layout (see Returns).
        return_vis: when True, draw the detections on a copy of *img* and write it
            to ``OUTPUT_DIR/vis_result.jpg``.

    Returns:
        Always a 2-tuple, so callers can unpack it in every case:
        - ``return_conf=True``  -> ``(conf, min_x)``. ``conf`` is the largest
          post-processed score. NOTE(review): post_process returns ``1 - score``,
          so ``conf`` is ``1 - lowest kept score``. ``min_x`` is the smallest
          x1 over all kept boxes, in original-image pixels.
        - ``return_conf=False`` -> ``(min_x, vis_img)``; ``vis_img`` is None unless
          *return_vis* is True.
        - no detection -> ``(None, None)``.

    Raises:
        RuntimeError: if the RKNN inference call returns None.
    """
    rknn = init_rknn(MODEL_PATH)
    img_resized, scale, dx, dy = letterbox_resize(img, IMG_SIZE)
    input_data = np.expand_dims(img_resized, 0)  # add the batch axis
    outputs = rknn.inference(inputs=[input_data])
    if outputs is None:
        # Fail here with a clear message rather than with an opaque TypeError in post_process.
        raise RuntimeError("RKNN inference failed (returned None)")
    boxes, classes, scores, _ = post_process(outputs, scale, dx, dy)
    if boxes is None or len(boxes) == 0:
        # Same arity as the success paths; a 1-tuple here broke `a, b = detect_bag(...)`.
        return None, None
    min_x = float(boxes[:, 0].min())
    vis_img = None
    if return_vis:
        vis_img = _draw_detections(img, boxes, classes, scores)
        cv2.imwrite(os.path.join(OUTPUT_DIR, "vis_result.jpg"), vis_img)
    if return_conf:
        return float(scores.max()), min_x
    return min_x, vis_img
# ====================== Manual test ======================
if __name__ == "__main__":
    # Smoke test: run the detector on IMG_PATH and print a one-line summary.
    image = cv2.imread(IMG_PATH)
    if image is None:
        raise FileNotFoundError(f"图片无法读取: {IMG_PATH}")
    # The flags control whether the confidence is returned and whether a visualisation is saved.
    best_conf, leftmost_x = detect_bag(image, return_conf=True, return_vis=True)
    if best_conf is not None:
        print(f"✅ 最大置信度: {best_conf:.4f}, 最左 x: {leftmost_x:.1f}")
    else:
        print("❌ 未检测到 bag")