215 lines
6.6 KiB
Python
215 lines
6.6 KiB
Python
import cv2
|
||
import numpy as np
|
||
import math
|
||
|
||
# ---------- Configuration ----------

# Class names; len(CLASSES) is the number of class-score channels per head.
CLASSES = ["clamp"]

# Rotated-IoU above which NMS suppresses a lower-scoring box of the same class.
nmsThresh = 0.4
# Minimum sigmoid class score for a grid cell to become a candidate box.
objectThresh = 0.5

# dtype the preprocessed image is cast to before rknn.inference().
INPUT_DTYPE = np.uint8

# Drawing switches used by myFunc().
DRAW_BOX = True     # draw the rotated box outlines
DRAW_SCORE = False  # draw the score text next to each box

# ---------------- Utility functions ----------------

class DetectBox:
    """A single rotated detection.

    Holds the class index, confidence score, the (xmin, ymin, xmax, ymax)
    extent of the unrotated rectangle, and ``angle``, the rotation applied
    about the rectangle's center (fed to math.cos/math.sin elsewhere in this
    file, i.e. radians).
    """

    def __init__(self, classId, score, xmin, ymin, xmax, ymax, angle):
        self.classId, self.score = classId, score
        self.xmin, self.ymin, self.xmax, self.ymax = xmin, ymin, xmax, ymax
        self.angle = angle

def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping its aspect ratio, then pad.

    Args:
        image: HxW or HxWxC array accepted by cv2.resize.
        size: (target_width, target_height) of the output canvas.
        bg_color: fill value for the padding border.

    Returns:
        (canvas, scale, dx, dy): the padded image (same dtype and channel
        layout as ``image``), the resize factor applied to both axes, and the
        left/top padding in pixels. A point (x, y) in ``image`` lands at
        approximately (x * scale + dx, y * scale + dy) in ``canvas``.
    """
    tw, th = size
    h, w = image.shape[:2]
    scale = min(tw / w, th / h)
    # round() rather than int(): float error in w * (tw / w) can land just
    # below the integer and truncate away a whole row/column of the target.
    # Clamp to [1, target] so cv2.resize never receives a zero-sized dsize.
    nw = min(tw, max(1, int(round(w * scale))))
    nh = min(th, max(1, int(round(h * scale))))
    img_resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
    # Keep the input's channel layout and dtype instead of forcing HxWx3
    # uint8, so grayscale / 4-channel frames don't break the slice assignment.
    # cv2.resize drops a trailing singleton channel, so reshape it back.
    extra = image.shape[2:]
    img_resized = img_resized.reshape((nh, nw) + extra)
    canvas = np.full((th, tw) + extra, bg_color, dtype=image.dtype)
    dx, dy = (tw - nw) // 2, (th - nh) // 2
    canvas[dy:dy + nh, dx:dx + nw] = img_resized
    return canvas, scale, dx, dy

def rotate_rectangle(x1, y1, x2, y2, a):
    """Return the 4 integer corners of rectangle (x1, y1)-(x2, y2) rotated by ``a``.

    The rectangle is rotated by ``a`` radians about its center. Corners are
    returned as [[x, y], ...], starting from the rotated images of (x1, y1),
    (x1, y2), (x2, y2), (x2, y1) in that order, each rounded to the nearest
    integer so the list can be passed straight to cv2.polylines.
    """
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    cos_a, sin_a = math.cos(a), math.sin(a)
    corners = ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
    # round() rather than int(): int() truncates toward zero, shifting each
    # vertex by up to one pixel, in opposite directions for negative and
    # positive coordinates.
    return [
        [
            int(round(cx + (xx - cx) * cos_a - (yy - cy) * sin_a)),
            int(round(cy + (xx - cx) * sin_a + (yy - cy) * cos_a)),
        ]
        for xx, yy in corners
    ]

def polygon_area(pts):
    """Return the unsigned area of a simple polygon (shoelace formula).

    ``pts`` is a sequence of (x, y) vertices in order, with either winding.
    Any number of vertices is accepted; fewer than 3 gives 0.0.
    """
    pts = list(pts)
    n = len(pts)
    if n < 3:
        return 0.0
    twice_area = sum(
        pts[i][0] * pts[(i + 1) % n][1] - pts[(i + 1) % n][0] * pts[i][1]
        for i in range(n)
    )
    return 0.5 * abs(twice_area)

def polygon_intersection_area(p1, p2):
|
||
try:
|
||
from shapely.geometry import Polygon
|
||
poly1, poly2 = Polygon(p1), Polygon(p2)
|
||
if not poly1.is_valid: poly1 = poly1.buffer(0)
|
||
if not poly2.is_valid: poly2 = poly2.buffer(0)
|
||
inter = poly1.intersection(poly2)
|
||
return inter.area if not inter.is_empty else 0
|
||
except:
|
||
return 0
|
||
|
||
def IoU(b1, b2):
    """Rotated intersection-over-union of two DetectBox-like objects.

    Each box becomes its rotated corner polygon via rotate_rectangle(). The
    overlap comes from polygon_intersection_area() and the individual areas
    from polygon_area(). Returns 0 unless the union exceeds 1e-6.
    """
    poly_a, poly_b = (
        rotate_rectangle(b.xmin, b.ymin, b.xmax, b.ymax, b.angle) for b in (b1, b2)
    )
    overlap = polygon_intersection_area(poly_a, poly_b)
    union = polygon_area(poly_a) + polygon_area(poly_b) - overlap
    if union > 1e-6:
        return overlap / union
    return 0

def NMS(boxes, iou_thresh=None):
    """Greedy per-class non-maximum suppression over rotated boxes.

    Boxes are visited in descending score order. A box is kept unless a
    higher-scoring kept box of the same class overlaps it with rotated
    IoU() greater than the threshold.

    Args:
        boxes: sequence of DetectBox-like objects (``classId``, ``score``,
            plus the ``xmin``/``ymin``/``xmax``/``ymax``/``angle`` read by
            IoU()). The objects are not modified.
        iou_thresh: suppression threshold; defaults to module-level nmsThresh.

    Returns:
        The kept boxes, highest score first.
    """
    if not boxes:
        return []
    if iou_thresh is None:
        iou_thresh = nmsThresh

    ordered = sorted(boxes, key=lambda b: b.score, reverse=True)
    # Track suppression by position instead of writing classId = -1 into the
    # caller's objects. That mutated the inputs and also discarded boxes whose
    # real class id happened to be -1.
    suppressed = [False] * len(ordered)
    keep = []
    for i, b1 in enumerate(ordered):
        if suppressed[i]:
            continue
        keep.append(b1)
        for j in range(i + 1, len(ordered)):
            if suppressed[j]:
                continue  # already removed; no need to compute its IoU again
            b2 = ordered[j]
            if b2.classId == b1.classId and IoU(b1, b2) > iou_thresh:
                suppressed[j] = True
    return keep

from scipy.special import expit


def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)), via scipy's expit."""
    result = expit(x)
    return result

def softmax(x, axis=-1):
    """Softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating, and 1e-9 is
    added to the normaliser.
    """
    arr = np.asarray(x)
    shifted = arr - arr.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    norm = weights.sum(axis=axis, keepdims=True) + 1e-9
    return weights / norm

# Bin indices 0..15, shaped (1, 1, 16, 1) to broadcast over the softmaxed
# (1, 4, 16, 1) side-distance distributions in process(), where summing
# along axis 2 yields each side's expected distance.
ARANGE16 = np.arange(16)[np.newaxis, np.newaxis, :, np.newaxis]

def process(out, mw, mh, stride, angle_feature, index, scale_w=1, scale_h=1):
    """Decode one detection head into rotated DetectBox candidates.

    Args:
        out: head output reshaped to (1, 64 + len(CLASSES), N). The first 64
            channels are box regression (4 sides x 16 bins), followed by one
            logit channel per class. N is assumed to be mh * mw.
        mw, mh: grid width and height of this head.
        stride: input pixels per grid cell.
        angle_feature: angle output shared by all heads. It is flattened here
            and read at ``index + row * mw + col``.
        index: offset of this head's cells inside ``angle_feature``.
        scale_w, scale_h: factors applied to the final x / y coordinates.

    Returns:
        A list of DetectBox, one per (class, cell) whose sigmoid score
        exceeds objectThresh, in ascending (class, row, col) order.
    """
    angle_feature = angle_feature.reshape(-1)
    class_num = len(CLASSES)
    cells = mh * mw

    # Threshold every (class, cell) score in one vectorised pass instead of a
    # Python loop over all mh*mw*class_num entries; only survivors go through
    # the per-box decode below. flatnonzero returns ascending indices, so the
    # output order matches a full scan.
    conf = sigmoid(out[:, 64:, :]).reshape(-1)[:cells * class_num]
    candidates = np.flatnonzero(conf > objectThresh)

    boxes = []
    for ik in candidates:
        c, cell = divmod(int(ik), cells)   # class index, flat cell = row * mw + col
        h, w = divmod(cell, mw)            # grid row, grid column

        # Expected distance of each of the 4 sides from its 16-bin distribution.
        dist = softmax(out[0, :64, cell].reshape(1, 4, 16, 1), axis=2)
        ltrb = np.sum(dist * ARANGE16, axis=2).reshape(-1)

        size = ltrb[:2] + ltrb[2:]           # box width, height in grid units
        offset = (ltrb[2:] - ltrb[:2]) / 2   # center shift before rotation

        angle = (angle_feature[index + cell] - 0.25) * math.pi
        cos_a, sin_a = math.cos(angle), math.sin(angle)

        # Rotate the center shift by the box angle, then move to pixel units.
        cx = (offset[0] * cos_a - offset[1] * sin_a + w + 0.5) * stride
        cy = (offset[0] * sin_a + offset[1] * cos_a + h + 0.5) * stride
        w_box = size[0] * stride
        h_box = size[1] * stride

        boxes.append(DetectBox(
            c, float(conf[ik]),
            (cx - w_box / 2) * scale_w, (cy - h_box / 2) * scale_h,
            (cx + w_box / 2) * scale_w, (cy + h_box / 2) * scale_h,
            float(angle),
        ))
    return boxes

# ---------------- RKNN inference interface ----------------

def myFunc(rknn, frame, lock=None):
    """Run rotated-box detection on ``frame`` and draw the two best boxes.

    Args:
        rknn: runtime object exposing ``inference(inputs)``. Its result list
            is treated as detection heads followed by the angle output as the
            last item.
        frame: BGR image to run on (converted BGR->RGB for the model); it is
            never drawn on directly.
        lock: accepted for interface compatibility; currently unused.

    Returns:
        A copy of ``frame`` with the two highest-scoring boxes drawn, or
        ``frame`` itself when inference fails or fewer than two boxes survive.
    """
    # Head grid height (for the 640x640 input) -> (stride, offset of that
    # head's cells in the flattened angle output, which stores the 80x80,
    # 40x40 and 20x20 grids back to back).
    heads = {80: (8, 0), 40: (16, 80 * 80), 20: (32, 80 * 80 + 40 * 40)}
    try:
        # --- letterbox to the 640x640 model input ---
        img_resized, scale, offset_x, offset_y = letterbox_resize(frame, (640, 640))
        infer_img = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
        infer_input = np.expand_dims(infer_img.astype(INPUT_DTYPE), 0)

        results = rknn.inference([infer_input])
        # Need at least one detection head plus the trailing angle output.
        if not results or len(results) < 2:
            return frame

        # 4 sides x 16 bins of box regression, then one logit per class.
        channels = 64 + len(CLASSES)
        angle_output = results[-1]
        outputs = []
        for x in results[:-1]:
            if x is None or x.shape[2] not in heads:
                continue
            stride, index = heads[x.shape[2]]
            outputs += process(
                x.reshape(1, channels, -1),
                x.shape[3],  # grid width
                x.shape[2],  # grid height
                stride,
                angle_output,
                index,
                1.0, 1.0,  # keep coordinates in the 640x640 letterboxed frame
            )

        if not outputs:
            return frame

        predbox = NMS(outputs)
        if len(predbox) < 2:
            return frame

        top_two = sorted(predbox, key=lambda b: b.score, reverse=True)[:2]

        out_frame = frame.copy()

        # ========== Map back to original-image coordinates ==========
        def restore_to_original(b):
            # Undo the letterbox: remove the padding, then the uniform resize.
            # round() (not int()) so coordinates in the padding region are not
            # biased toward zero.
            return (
                int(round((b.xmin - offset_x) / scale)),
                int(round((b.ymin - offset_y) / scale)),
                int(round((b.xmax - offset_x) / scale)),
                int(round((b.ymax - offset_y) / scale)),
            )

        for box in top_two:
            xmin, ymin, xmax, ymax = restore_to_original(box)

            if DRAW_BOX:
                # Rotated box vertices. Rotating in original-image coordinates
                # is valid because the letterbox scales x and y by the same factor.
                pts = rotate_rectangle(xmin, ymin, xmax, ymax, box.angle)
                cv2.polylines(out_frame, [np.array(pts, np.int32)], True, (0, 255, 0), 2)

            if DRAW_SCORE:
                cv2.putText(
                    out_frame, f"{box.score:.2f}",
                    (xmin, max(10, ymin - 6)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2,
                )

        return out_frame

    except Exception as e:
        print(f"[func ❌] 推理异常: {e}")
        return frame
