Files
zjsh_yolov11/ailai_obb/bushu_angle.py
2025-09-15 15:35:19 +08:00

198 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import math
from shapely.geometry import Polygon
from rknnlite.api import RKNNLite
import os
# ------------------- Configuration -------------------
# Class names of the model; len(CLASSES) is the number of class channels
# that process() decodes after the 64 box-regression channels.
CLASSES = ["clamp"]
# Rotated-IoU above which a lower-scoring box of the same class is suppressed.
nmsThresh = 0.4
# Minimum sigmoid confidence for a grid cell to be decoded into a box.
objectThresh = 0.5

# ------------------- Nominal source-image size -------------------
# NOTE(review): not referenced by the code visible in this module; presumably
# kept for external callers - confirm before removing.
ORIG_W, ORIG_H = 2560, 1440  # original image width / height in pixels
# ------------------- Utility functions -------------------
def letterbox_resize(image, size, bg_color=114):
    """Resize ``image`` to fit inside ``size`` keeping its aspect ratio, padding the rest.

    Args:
        image: H x W x 3 uint8 image (e.g. BGR as returned by cv2.imread).
        size: (target_width, target_height) of the output canvas.
        bg_color: padding value - a scalar, or a per-channel 3-tuple such as
            (114, 114, 114).

    Returns:
        (canvas, scale, offset_x, offset_y): the padded uint8 canvas, the resize
        factor applied to the image, and the left/top padding in pixels. A canvas
        point (u, v) maps back to ((u - offset_x) / scale, (v - offset_y) / scale).
    """
    target_width, target_height = size
    image_height, image_width, _ = image.shape
    scale = min(target_width / image_width, target_height / image_height)
    # Clamp to at least one pixel: int() truncation can yield 0 for extremely
    # elongated images, and cv2.resize rejects a zero-sized destination.
    new_width = max(1, int(image_width * scale))
    new_height = max(1, int(image_height * scale))
    image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    # An explicit uint8 dtype keeps the canvas uint8 for both scalar and tuple
    # bg_color (np.ones(...) * tuple silently produced an int64 canvas).
    canvas = np.full((target_height, target_width, 3), bg_color, dtype=np.uint8)
    offset_x = (target_width - new_width) // 2
    offset_y = (target_height - new_height) // 2
    canvas[offset_y:offset_y + new_height, offset_x:offset_x + new_width] = image_resized
    return canvas, scale, offset_x, offset_y
class DetectBox:
    """A single decoded oriented detection.

    ``xmin``/``ymin``/``xmax``/``ymax`` describe the box before rotation, in
    original-image pixels as computed by process(); ``angle`` (radians) is the
    rotation applied about the box centre by rotate_rectangle().
    """

    def __init__(self, classId, score, xmin, ymin, xmax, ymax, angle):
        self.classId, self.score = classId, score
        self.xmin, self.ymin = xmin, ymin
        self.xmax, self.ymax = xmax, ymax
        self.angle = angle
def rotate_rectangle(x1, y1, x2, y2, a):
    """Rotate the rectangle spanned by (x1, y1)-(x2, y2) by ``a`` radians about its centre.

    Returns the four rotated corners as integer (x, y) tuples in outline order
    (x1, y1) -> (x1, y2) -> (x2, y2) -> (x2, y1). Coordinates are truncated
    toward zero with int(), not rounded.
    """
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    cos_a = math.cos(a)
    sin_a = math.sin(a)

    def _turn(px, py):
        dx, dy = px - cx, py - cy
        return (int(dx * cos_a - dy * sin_a + cx),
                int(dx * sin_a + dy * cos_a + cy))

    outline = ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
    return [_turn(px, py) for px, py in outline]
def intersection(g, p):
    """Return the IoU (intersection area over union area) of two polygons.

    ``g`` and ``p`` are vertex sequences reshaped to (N, 2) points. Returns 0 if
    shapely reports either polygon as invalid, or if the union area is zero.
    """
    poly_a = Polygon(np.array(g).reshape(-1, 2))
    poly_b = Polygon(np.array(p).reshape(-1, 2))
    if poly_a.is_valid and poly_b.is_valid:
        overlap = poly_a.intersection(poly_b).area
        union = poly_a.area + poly_b.area - overlap
        if union != 0:
            return overlap / union
    return 0
def NMS(detectResult, thresh=None):
    """Greedy per-class non-maximum suppression for oriented boxes.

    Boxes are visited in descending score order; each kept box suppresses every
    later box of the same class whose rotated-polygon IoU with it exceeds
    ``thresh``. Boxes whose classId is -1 on entry are ignored.

    Args:
        detectResult: iterable of DetectBox.
        thresh: IoU suppression threshold; defaults to the module-level nmsThresh.

    Returns:
        List of the kept DetectBox instances, highest score first. The input
        boxes are not modified.
    """
    if thresh is None:
        thresh = nmsThresh
    ordered = sorted(detectResult, key=lambda b: b.score, reverse=True)
    # Rotate each box once up front instead of once per (i, j) comparison.
    polygons = [rotate_rectangle(b.xmin, b.ymin, b.xmax, b.ymax, b.angle) for b in ordered]
    # Track suppression locally rather than overwriting classId on the caller's objects.
    suppressed = [b.classId == -1 for b in ordered]
    kept = []
    for i, box in enumerate(ordered):
        if suppressed[i]:
            continue
        kept.append(box)
        for j in range(i + 1, len(ordered)):
            if (not suppressed[j]
                    and ordered[j].classId == box.classId
                    and intersection(polygons[i], polygons[j]) > thresh):
                suppressed[j] = True
    return kept
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)) without overflow.

    Accepts a scalar or array-like; returns a NumPy array (0-d for scalars) with
    the floating dtype NumPy's exp produces for the input.
    """
    x = np.asarray(x)
    # np.where evaluates both branch expressions for every element, so each
    # branch must be safe everywhere. exp(-|x|) lies in [0, 1] and never
    # overflows, unlike exp(-x) / exp(x) evaluated on the "wrong" sign.
    e = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + e), e / (1 + e))
def softmax(x, axis=-1):
    """Normalised exponentials of ``x`` along ``axis``.

    The per-slice maximum is subtracted first so exp() cannot overflow; this
    shift does not change the result.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total
# ------------------- Head decoding (boxes mapped back through the letterbox) -------------------
def process(out, model_w, model_h, stride, angle_feature, index, scale=1.0, offset_x=0, offset_y=0):
    """Decode one YOLO-OBB output head into DetectBox objects in original-image pixels.

    Args:
        out: head output reshaped to (1, 64 + len(CLASSES), model_h * model_w):
            64 box-regression channels (4 sides x 16 DFL bins) followed by one
            confidence logit per class.
        model_w, model_h: grid width / height of this head.
        stride: input-image pixels per grid cell for this head.
        angle_feature: angle output shared by all heads; flattened here.
        index: offset of this head's first cell in the flattened angle output.
        scale, offset_x, offset_y: letterbox parameters (see letterbox_resize)
            used to undo the resize and padding.

    Returns:
        List of DetectBox whose confidence exceeds objectThresh, in ascending
        (class, row, column) order. xmin..ymax are the un-rotated extent; angle
        is in radians.
    """
    class_num = len(CLASSES)
    cells = model_h * model_w
    angle_feature = angle_feature.reshape(-1)
    last_angle = len(angle_feature) - 1
    xywh = out[:, :64, :]
    conf = sigmoid(out[:, 64:, :]).reshape(-1)
    # Bin indices 0..15 used to take the expectation of each side's DFL
    # distribution; loop-invariant, so built once.
    bins = np.arange(16).reshape(1, 1, 16, 1)
    boxes = []
    # Visit only above-threshold entries; flatnonzero yields them in ascending
    # order, matching a full scan of the first cells * class_num entries.
    for ik in np.flatnonzero(conf[:cells * class_num] > objectThresh).tolist():
        w = ik % model_w
        h = (ik % cells) // model_w
        c = ik // cells
        cell = h * model_w + w
        # DFL decode: softmax over the 16 bins of each side, then the expected
        # bin index. Assumed YOLOv8 ltrb layout: [:2] left/top, [2:] right/bottom.
        dist = softmax(xywh[0, :, cell].reshape(1, 4, 16, 1), 2)
        dist = np.sum(dist * bins, axis=2).reshape(-1)
        size_wh = dist[:2] + dist[2:]
        centre_off = (dist[2:] - dist[:2]) / 2
        # NOTE(review): the clamp hides an index mismatch between heads and the
        # angle output by silently reusing the last angle instead of raising.
        angle_idx = min(index + cell, last_angle)
        angle = (angle_feature[angle_idx] - 0.25) * math.pi
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        # Rotate the centre offset by the box angle, then add the cell centre.
        cx = centre_off[0] * cos_a - centre_off[1] * sin_a + w + 0.5
        cy = centre_off[0] * sin_a + centre_off[1] * cos_a + h + 0.5
        box = np.array([cx, cy, size_wh[0], size_wh[1]]) * stride
        # Map from letterboxed model input back to original-image pixels.
        xmin = (box[0] - box[2] / 2 - offset_x) / scale
        ymin = (box[1] - box[3] / 2 - offset_y) / scale
        xmax = (box[0] + box[2] / 2 - offset_x) / scale
        ymax = (box[1] + box[3] / 2 - offset_y) / scale
        boxes.append(DetectBox(c, conf[ik], xmin, ymin, xmax, ymax, angle))
    return boxes
# ------------------- Public entry point -------------------
def detect_boxes_rknn(model_path, image_path):
    """Run the RKNN OBB model on one image and return NMS-filtered detections.

    Args:
        model_path: path to the .rknn model file.
        image_path: path to the input image (read with cv2.imread).

    Returns:
        (detect_boxes, img): list of DetectBox in original-image pixels and the
        BGR image as read; (None, None) if the image cannot be read, the model
        cannot be loaded or initialised, or inference returns nothing.

    Raises:
        ValueError: if an output head's grid height is not 80, 40 or 20.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 无法读取图像: {image_path}")
        return None, None
    img_resized, scale, offset_x, offset_y = letterbox_resize(img, (640, 640))
    infer_img = np.expand_dims(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB), 0)

    # Grid height -> (stride, offset of the head's first cell in the flattened
    # angle output) for a 640x640 input: 80x80 cells first, then 40x40, then 20x20.
    head_layout = {80: (8, 0), 40: (16, 80 * 80), 20: (32, 80 * 80 + 40 * 40)}

    rknn_lite = RKNNLite(verbose=False)
    try:
        if rknn_lite.load_rknn(model_path) != 0:
            print(f"❌ Failed to load RKNN model: {model_path}")
            return None, None
        if rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0) != 0:
            print("❌ Failed to initialise the RKNN runtime")
            return None, None
        results = rknn_lite.inference([infer_img])
        if results is None:
            print(f"❌ RKNN inference failed for: {image_path}")
            return None, None
        detect_boxes = []
        # All outputs except the last are detection heads; the last is the angle map.
        for x in results[:-1]:
            grid_h, grid_w = x.shape[2], x.shape[3]
            if grid_h not in head_layout:
                raise ValueError(
                    f"Unexpected output grid {grid_h}x{grid_w}; "
                    "expected 80, 40 or 20 for a 640x640 input")
            stride, index = head_layout[grid_h]
            feature = x.reshape(1, 64 + len(CLASSES), -1)
            detect_boxes += process(feature, grid_w, grid_h, stride, results[-1], index,
                                    scale=scale, offset_x=offset_x, offset_y=offset_y)
        detect_boxes = NMS(detect_boxes)
    finally:
        # Free the NPU context even when loading, inference or decoding fails.
        rknn_lite.release()
    return detect_boxes, img
# ------------------- Drawing and helper functions -------------------
def get_angles(detect_boxes):
    """Return the rotation angle (radians) of every box, preserving order.

    ``None`` (what detect_boxes_rknn returns on failure) is treated as no boxes
    and yields an empty list instead of raising TypeError.
    """
    if detect_boxes is None:
        return []
    return [box.angle for box in detect_boxes]
def draw_boxes(img, detect_boxes, save_path=None):
    """Draw every oriented box and its angle in degrees onto ``img`` in place.

    Args:
        img: BGR image to draw on; it is modified and also returned.
        detect_boxes: iterable of DetectBox; None is treated as no boxes.
        save_path: optional output file; its parent directory is created if needed.

    Returns:
        The same ``img`` object.
    """
    if detect_boxes is None:
        detect_boxes = []
    for box in detect_boxes:
        points = rotate_rectangle(box.xmin, box.ymin, box.xmax, box.ymax, box.angle)
        cv2.polylines(img, [np.array(points, np.int32)], True, (0, 255, 0), 1)
        # Hershey fonts only render ASCII; a "°" would be drawn as "?", so use "deg".
        cv2.putText(img, f"{np.degrees(box.angle):.1f}deg", (int(box.xmin), int(box.ymin) - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:  # os.makedirs("") raises for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        if cv2.imwrite(save_path, img):
            print(f"✅ 带角度的检测结果已保存到 {save_path}")
        else:
            print(f"❌ Failed to write image: {save_path}")
    return img
def visualize_top_box(img, detect_boxes, save_path=None):
    """Draw only the highest-scoring box and its angle in degrees onto ``img`` in place.

    Args:
        img: BGR image to draw on; it is modified and also returned.
        detect_boxes: sequence of DetectBox; when empty or None, ``img`` is
            returned untouched and nothing is written.
        save_path: optional output file; its parent directory is created if needed.

    Returns:
        The same ``img`` object.
    """
    if not detect_boxes:
        return img
    top_box = max(detect_boxes, key=lambda x: x.score)
    points = rotate_rectangle(top_box.xmin, top_box.ymin, top_box.xmax, top_box.ymax, top_box.angle)
    cv2.polylines(img, [np.array(points, np.int32)], True, (0, 255, 0), 2)
    # Hershey fonts only render ASCII; a "°" would be drawn as "?", so use "deg".
    cv2.putText(img, f"{np.degrees(top_box.angle):.1f}deg", (int(top_box.xmin), int(top_box.ymin) - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:  # os.makedirs("") raises for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        if not cv2.imwrite(save_path, img):
            print(f"❌ Failed to write image: {save_path}")
    return img
# ------------------- Example usage -------------------
if __name__ == "__main__":
    model_path = "obb.rknn"
    image_path = "2.jpg"
    detect_boxes, img = detect_boxes_rknn(model_path, image_path)
    # detect_boxes_rknn has already printed the reason when it returns (None, None).
    if detect_boxes is None:
        raise SystemExit(1)
    for i, angle in enumerate(get_angles(detect_boxes), start=1):
        print(f"{i}: angle = {angle:.4f} rad ({np.degrees(angle):.2f}°)")
    save_path_all = "./inference_results/boxes_all.jpg"
    draw_boxes(img.copy(), detect_boxes, save_path_all)
    save_path_top = "./inference_results/top_box.jpg"
    visualize_top_box(img.copy(), detect_boxes, save_path_top)