Files
xiantiao_CV/rknn-multi-threaded-nosigmoid/func_obb.py
琉璃月光 8506c3af79 first commit
2025-12-16 15:12:02 +08:00

215 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import math
# ---------- Configuration ----------
# Class names; process() uses len(CLASSES) as the number of class channels.
CLASSES = ['clamp']
# NMS() suppresses a lower-scored box whose IoU with a kept box exceeds this.
nmsThresh = 0.4
# process() discards candidates whose confidence is not above this value.
objectThresh = 0.5
# dtype the letterboxed frame is cast to before being passed to inference.
INPUT_DTYPE = np.uint8
# When True, myFunc() draws the rotated box outline on the output frame.
DRAW_BOX = True
# When True, myFunc() also draws the confidence score next to each box.
DRAW_SCORE = False
# ---------------- 工具函数 ----------------
class DetectBox:
    """One oriented detection candidate.

    ``xmin``/``ymin``/``xmax``/``ymax`` describe the box before rotation;
    elsewhere in this file the corners are obtained by rotating that
    rectangle by ``angle`` (radians) about its centre.  Attributes are plain
    and mutable.
    """

    def __init__(self, classId, score, xmin, ymin, xmax, ymax, angle):
        # Identity of the detection.
        self.classId, self.score = classId, score
        # Un-rotated rectangle extents.
        self.xmin, self.ymin, self.xmax, self.ymax = xmin, ymin, xmax, ymax
        # Rotation applied to the rectangle.
        self.angle = angle
def letterbox_resize(image, size, bg_color=114):
    """Scale *image* to fit inside *size* without distortion and pad the rest.

    Args:
        image: Array whose first two dimensions are (height, width).  The
            resized image is pasted into a 3-channel canvas.
        size: ``(target_width, target_height)`` of the output canvas.
        bg_color: Fill value used for the padding area.

    Returns:
        ``(canvas, scale, dx, dy)`` where ``canvas`` is the padded
        ``uint8`` image of shape ``(target_height, target_width, 3)``,
        ``scale`` is the uniform resize factor, and ``dx``/``dy`` are the
        left/top offsets of the resized image inside the canvas.
    """
    tw, th = size
    h, w = image.shape[:2]
    scale = min(tw / w, th / h)
    # int() truncates, so a very elongated frame could yield a 0-pixel side,
    # which cv2.resize rejects.  Keep at least one pixel per side.
    nw = max(1, int(w * scale))
    nh = max(1, int(h * scale))
    img_resized = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.full((th, tw, 3), bg_color, dtype=np.uint8)
    # Centre the resized image on the canvas.
    dx, dy = (tw - nw) // 2, (th - nh) // 2
    canvas[dy:dy + nh, dx:dx + nw] = img_resized
    return canvas, scale, dx, dy
def rotate_rectangle(x1, y1, x2, y2, a):
    """Rotate the rectangle (x1, y1)-(x2, y2) by *a* radians about its centre.

    Returns:
        Four ``[x, y]`` corner points with integer coordinates (truncated
        via ``int``), in the order (x1, y1), (x1, y2), (x2, y2), (x2, y1)
        of the un-rotated rectangle.
    """
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    cos_a = math.cos(a)
    sin_a = math.sin(a)

    rotated = []
    for px, py in ((x1, y1), (x1, y2), (x2, y2), (x2, y1)):
        # Offset of this corner from the rectangle centre.
        off_x = px - cx
        off_y = py - cy
        rotated.append([
            int(cx + off_x * cos_a - off_y * sin_a),
            int(cy + off_x * sin_a + off_y * cos_a),
        ])
    return rotated
def polygon_area(pts):
    """Return the area of a simple polygon using the shoelace formula.

    Args:
        pts: Iterable of ``(x, y)`` vertices in boundary order (either
            winding direction).  Any number of vertices is accepted.

    Returns:
        The non-negative area as a float; ``0.0`` when fewer than three
        vertices are given.
    """
    pts = list(pts)
    n = len(pts)
    if n < 3:
        # A point or a segment encloses no area.
        return 0.0
    twice_signed_area = sum(
        pts[i][0] * pts[(i + 1) % n][1] - pts[(i + 1) % n][0] * pts[i][1]
        for i in range(n)
    )
    return 0.5 * abs(twice_signed_area)
def _signed_area(pts):
    """Return the signed shoelace area of *pts* (positive when counter-clockwise)."""
    n = len(pts)
    return 0.5 * sum(
        pts[i][0] * pts[(i + 1) % n][1] - pts[(i + 1) % n][0] * pts[i][1]
        for i in range(n)
    )


def _convex_intersection_area(subject, clip):
    """Return the overlap area of two convex polygons (Sutherland-Hodgman clipping)."""
    subject = [(float(x), float(y)) for x, y in subject]
    clip = [(float(x), float(y)) for x, y in clip]
    if len(subject) < 3 or len(clip) < 3:
        return 0.0

    orientation = _signed_area(clip)
    if orientation == 0:
        # Degenerate clip polygon encloses nothing.
        return 0.0
    if orientation < 0:
        # The inside test below assumes a counter-clockwise clip polygon.
        clip.reverse()

    output = subject
    for i, (ax, ay) in enumerate(clip):
        if not output:
            return 0.0
        bx, by = clip[(i + 1) % len(clip)]
        ex, ey = bx - ax, by - ay
        current, output = output, []
        prev = current[-1]
        prev_side = ex * (prev[1] - ay) - ey * (prev[0] - ax)
        for cur in current:
            # side >= 0 means the point lies on the inner side of this edge.
            cur_side = ex * (cur[1] - ay) - ey * (cur[0] - ax)
            if (cur_side >= 0) != (prev_side >= 0):
                # The segment prev->cur crosses the clip edge; keep the crossing.
                t = prev_side / (prev_side - cur_side)
                output.append((
                    prev[0] + t * (cur[0] - prev[0]),
                    prev[1] + t * (cur[1] - prev[1]),
                ))
            if cur_side >= 0:
                output.append(cur)
            prev, prev_side = cur, cur_side

    if len(output) < 3:
        return 0.0
    return abs(_signed_area(output))


def polygon_intersection_area(p1, p2):
    """Return the area of the intersection of polygons *p1* and *p2*.

    Shapely is used when it can be imported.  When it is not installed the
    function falls back to a pure-Python clipper that is exact for convex
    polygons (such as rotated rectangles) instead of silently reporting
    zero overlap.

    Args:
        p1, p2: Sequences of ``(x, y)`` vertices in boundary order.

    Returns:
        The overlap area; ``0`` when the polygons do not overlap or when
        Shapely fails on the given geometry.
    """
    try:
        from shapely.geometry import Polygon
    except ImportError:
        return _convex_intersection_area(p1, p2)

    try:
        poly1, poly2 = Polygon(p1), Polygon(p2)
        # buffer(0) repairs invalid (e.g. self-touching) rings.
        if not poly1.is_valid:
            poly1 = poly1.buffer(0)
        if not poly2.is_valid:
            poly2 = poly2.buffer(0)
        inter = poly1.intersection(poly2)
        return inter.area if not inter.is_empty else 0
    except Exception:
        # Best effort: geometry Shapely cannot handle counts as no overlap.
        return 0
def IoU(b1, b2):
    """Return the intersection-over-union of two rotated boxes.

    Each box is expanded to its four rotated corner points, then the
    overlap and individual polygon areas are combined.  Returns ``0`` when
    the union area is not greater than ``1e-6``.
    """
    corners = [
        rotate_rectangle(b.xmin, b.ymin, b.xmax, b.ymax, b.angle)
        for b in (b1, b2)
    ]
    overlap = polygon_intersection_area(corners[0], corners[1])
    combined = polygon_area(corners[0]) + polygon_area(corners[1]) - overlap
    if combined > 1e-6:
        return overlap / combined
    return 0
def NMS(boxes):
    """Apply per-class non-maximum suppression to rotated boxes.

    Boxes are visited in descending score order; a box is dropped when a
    kept box of the same class overlaps it with IoU above ``nmsThresh``.
    The input boxes are not modified.

    Args:
        boxes: Sequence of objects exposing ``score`` and ``classId`` (plus
            whatever ``IoU`` reads).  A box whose ``classId`` is already
            ``-1`` is treated as suppressed and never returned.

    Returns:
        The surviving boxes, highest score first.
    """
    if not boxes:
        return []
    ranked = sorted(boxes, key=lambda b: b.score, reverse=True)
    # Track suppression locally rather than overwriting classId on the
    # caller's objects.
    suppressed = [b.classId == -1 for b in ranked]
    keep = []
    for i, winner in enumerate(ranked):
        if suppressed[i]:
            continue
        keep.append(winner)
        for j in range(i + 1, len(ranked)):
            if suppressed[j]:
                continue
            other = ranked[j]
            if other.classId == winner.classId and IoU(winner, other) > nmsThresh:
                suppressed[j] = True
    return keep
from scipy.special import expit
def sigmoid(x):
    """Return the logistic sigmoid of *x* (delegates to ``scipy.special.expit``)."""
    return expit(x)
def softmax(x, axis=-1):
    """Return the softmax of *x* along *axis*.

    The per-axis maximum is subtracted before exponentiation, and ``1e-9``
    is added to the normalising sum.
    """
    arr = np.asarray(x)
    shifted = arr - np.max(arr, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    denom = exps.sum(axis=axis, keepdims=True) + 1e-9
    return exps / denom
# Bin indices 0..15 shaped (1, 1, 16, 1) so they broadcast against the
# (1, 4, 16, 1) softmax output in process(), where the weighted sum over the
# 16 bins gives one value per box side.
ARANGE16 = np.arange(16).reshape(1,1,16,1)
def process(out, mw, mh, stride, angle_feature, index, scale_w=1, scale_h=1):
    """Decode one detection head into a list of ``DetectBox`` candidates.

    Args:
        out: Head output, sliced as ``out[:, :64, :]`` (4 x 16 box
            distribution logits) and ``out[:, 64:, :]`` (class logits).
            The caller in this file passes it reshaped to ``(1, 65, -1)``.
        mw: Grid width of this head.
        mh: Grid height of this head.
        stride: Multiplier converting grid units to input-image pixels.
        angle_feature: Angle values for all heads; flattened here and read
            at ``index + h * mw + w`` for grid cell ``(h, w)``.
        index: Offset of this head's first cell inside ``angle_feature``.
        scale_w: Extra factor applied to the x coordinates.
        scale_h: Extra factor applied to the y coordinates.

    Returns:
        A list of ``DetectBox`` for every cell/class whose confidence is
        above ``objectThresh``, in ascending flat-index order.
    """
    angle_feature = angle_feature.reshape(-1)
    xywh = out[:, :64, :]
    conf = sigmoid(out[:, 64:, :]).reshape(-1)
    boxes = []
    class_num = len(CLASSES)
    total = mh * mw * class_num

    # Pick candidates with one vectorised comparison instead of visiting every
    # grid cell in Python; flatnonzero yields indices in ascending order, so
    # the output order matches a plain scan.  NaN scores compare False and are
    # therefore dropped.
    for ik in np.flatnonzero(conf[:total] > objectThresh):
        ik = int(ik)
        w = ik % mw
        h = (ik // mw) % mh
        c = ik // (mw * mh)

        # Expected value over the 16 bins gives the four side distances.
        xywh_ = xywh[0, :, h * mw + w].reshape(1, 4, 16, 1)
        xywh_ = softmax(xywh_, axis=2)
        xywh_ = np.sum(xywh_ * ARANGE16, axis=2).reshape(-1)
        xy_add = xywh_[:2] + xywh_[2:]
        xy_sub = (xywh_[2:] - xywh_[:2]) / 2

        angle = (angle_feature[index + h * mw + w] - 0.25) * math.pi
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        # Rotate the centre offset by the box angle before adding the cell
        # position.
        xy_rot = np.array([
            xy_sub[0] * cos_a - xy_sub[1] * sin_a,
            xy_sub[0] * sin_a + xy_sub[1] * cos_a
        ])
        cx = (xy_rot[0] + w + 0.5) * stride
        cy = (xy_rot[1] + h + 0.5) * stride
        w_box = xy_add[0] * stride
        h_box = xy_add[1] * stride

        xmin = (cx - w_box / 2) * scale_w
        ymin = (cy - h_box / 2) * scale_h
        xmax = (cx + w_box / 2) * scale_w
        ymax = (cy + h_box / 2) * scale_h
        boxes.append(DetectBox(c, float(conf[ik]), xmin, ymin, xmax, ymax, float(angle)))
    return boxes
# ---------------- RKNN 推理接口 ----------------
def myFunc(rknn, frame, lock=None):
    """Run one inference on *frame* and draw the two best rotated boxes.

    Args:
        rknn: Object exposing ``inference(list_of_inputs)``; all returned
            entries except the last are treated as detection heads and the
            last one as the angle tensor.
        frame: BGR image to process.
        lock: Accepted but not used by this function.

    Returns:
        A copy of *frame* with the two highest-scoring boxes drawn, or
        *frame* itself when inference yields nothing, fewer than two boxes
        survive NMS, or any exception is raised (the error is printed).
    """
    try:
        # Letterbox to the 640x640 network input, then convert BGR -> RGB.
        letterboxed, scale, offset_x, offset_y = letterbox_resize(frame, (640, 640))
        rgb = cv2.cvtColor(letterboxed, cv2.COLOR_BGR2RGB)
        batch = np.expand_dims(rgb.astype(INPUT_DTYPE), 0)

        results = rknn.inference([batch])
        if not results or len(results) < 1:
            return frame

        # Grid height -> (stride, offset of this head inside the angle tensor).
        head_layout = {
            20: (32, 20*4*20*4 + 20*2*20*2),
            40: (16, 20*4*20*4),
            80: (8, 0),
        }
        angle_feature = results[-1]

        candidates = []
        for head in results[:-1]:
            if head is None:
                continue
            layout = head_layout.get(head.shape[2])
            if layout is None:
                # Unknown grid size: skip this output.
                continue
            stride, index = layout
            candidates.extend(
                process(
                    head.reshape(1, 65, -1),
                    head.shape[3],
                    head.shape[2],
                    stride,
                    angle_feature,
                    index,
                    1.0, 1.0,  # keep coordinates in the 640x640 input space
                )
            )

        if not candidates:
            return frame

        survivors = NMS(candidates)
        if len(survivors) < 2:
            return frame

        top_two = sorted(survivors, key=lambda b: b.score, reverse=True)[:2]
        out_frame = frame.copy()

        for box in top_two:
            # Undo the letterbox: remove the padding offset, then the scale.
            xmin = int((box.xmin - offset_x) / scale)
            ymin = int((box.ymin - offset_y) / scale)
            xmax = int((box.xmax - offset_x) / scale)
            ymax = int((box.ymax - offset_y) / scale)

            if DRAW_BOX:
                # Corners are rotated in original-frame coordinates.
                corners = rotate_rectangle(xmin, ymin, xmax, ymax, box.angle)
                cv2.polylines(out_frame, [np.array(corners, np.int32)], True, (0,255,0), 2)
            if DRAW_SCORE:
                cv2.putText(
                    out_frame, f"{box.score:.2f}",
                    (xmin, max(10, ymin - 6)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2
                )
        return out_frame
    except Exception as e:
        print(f"[func ❌] 推理异常: {e}")
        return frame