Files
zjsh_yolov11/angle_base_seg/bushu.py
2025-09-01 14:14:18 +08:00

207 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
import argparse
import torch
import torch.nn.functional as F
import torchvision
# ---------------- Configuration ----------------
# Minimum score a candidate must reach to survive filtering in post_process.
OBJ_THRESH = 0.25
# IoU threshold handed to torchvision.ops.nms in post_process.
NMS_THRESH = 0.45
# Upper bound on the number of detections kept after NMS.
MAX_DETECT = 300
# Network input resolution; main() resizes every image to this size.
IMG_SIZE = (640, 640) # (width, height)
# Directory that receives the visualisation images written by compute_angle.
OUTPUT_DIR = "result"
# Make sure the output directory exists as soon as the module is loaded.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ---------------- Utility functions ----------------
def sigmoid(x):
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    Args:
        x: A scalar or NumPy array.

    Returns:
        The logistic of ``x``; arrays keep their shape.
    """
    decay = np.exp(-x)
    return 1 / (1 + decay)
def dfl(position):
    """Collapse per-side bin logits into one expected value per box side.

    The channel axis is split into 4 groups of ``c // 4`` bins. A softmax is
    taken over the bins of each group and the result is the
    probability-weighted mean of the bin indices.

    Args:
        position: Array-like of shape ``(n, c, h, w)``.

    Returns:
        NumPy array of shape ``(n, 4, h, w)``.
    """
    logits = torch.tensor(position)
    batch, channels, rows, cols = logits.shape
    bins = channels // 4
    probs = logits.reshape(batch, 4, bins, rows, cols).softmax(dim=2)
    # Bin indices 0..bins-1, shaped to broadcast along the bin axis.
    bin_index = torch.arange(bins).float().reshape(1, 1, bins, 1, 1)
    expected = (probs * bin_index).sum(dim=2)
    return expected.numpy()
def box_process(position):
    """Decode one branch of DFL box logits into ``xyxy`` pixel boxes.

    Args:
        position: NumPy array of shape ``(n, c, grid_h, grid_w)`` holding the
            raw box logits of one detection branch (decoded via ``dfl``).

    Returns:
        NumPy array of shape ``(n, 4, grid_h, grid_w)`` whose channels are
        ``(x1, y1, x2, y2)`` expressed in ``IMG_SIZE`` pixels.
    """
    grid_h, grid_w = position.shape[2:4]
    col, row = np.meshgrid(np.arange(0, grid_w), np.arange(0, grid_h))
    col = col.reshape(1, 1, grid_h, grid_w)
    row = row.reshape(1, 1, grid_h, grid_w)
    # Channel 0 carries the x (column) index, channel 1 the y (row) index.
    grid = np.concatenate((col, row), axis=1)
    # The stride must follow the same (x, y) channel order as the grid:
    # IMG_SIZE is (W, H), so x uses W // grid_w and y uses H // grid_h.
    stride = np.array(
        [IMG_SIZE[0] // grid_w, IMG_SIZE[1] // grid_h]
    ).reshape(1, 2, 1, 1)
    position = dfl(position)
    # DFL yields distances from the cell centre to the left/top and
    # right/bottom edges, measured in grid cells.
    top_left = grid + 0.5 - position[:, 0:2, :, :]
    bottom_right = grid + 0.5 + position[:, 2:4, :, :]
    return np.concatenate((top_left * stride, bottom_right * stride), axis=1)
def _crop_mask(masks, boxes):
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def post_process(input_data):
    """Turn raw segmentation-head outputs into boxes, scores and binary masks.

    Args:
        input_data: Sequence of NumPy arrays. The last entry is the mask
            prototype tensor; the entries before it are split evenly over
            three detection branches. Inside a branch, offset 0 is read as
            the box logits, offset 1 as the per-class confidences and
            offset 3 as the mask coefficients.
            NOTE(review): this matches the RKNN YOLO-seg export layout where
            confidences are already probabilities - confirm for the deployed
            model.

    Returns:
        Tuple ``(boxes, scores, seg_img)``:
            * ``boxes``: ``(k, 4)`` array, ``xyxy`` in ``IMG_SIZE`` pixels.
            * ``scores``: ``(k, 1)`` array with the best class confidence.
            * ``seg_img``: ``(k, H, W)`` boolean masks at ``IMG_SIZE``
              resolution, cropped to their boxes.
        All three are empty (``k == 0``) when nothing passes ``OBJ_THRESH``.
    """
    proto = input_data[-1]
    default_branch = 3
    pair_per_branch = len(input_data) // default_branch

    boxes, scores, seg_part = [], [], []
    for i in range(default_branch):
        base = pair_per_branch * i
        boxes.append(box_process(input_data[base]))
        # Score each anchor by its best class confidence. A constant 1.0
        # would make the threshold a no-op and the NMS ranking arbitrary.
        class_conf = np.asarray(input_data[base + 1], dtype=np.float32)
        scores.append(class_conf.max(axis=1, keepdims=True))
        seg_part.append(input_data[base + 3])

    def sp_flatten(_in):
        # (n, ch, h, w) -> (n * h * w, ch)
        ch = _in.shape[1]
        return _in.transpose(0, 2, 3, 1).reshape(-1, ch)

    boxes = np.concatenate([sp_flatten(v) for v in boxes])
    scores = np.concatenate([sp_flatten(v) for v in scores])
    seg_part = np.concatenate([sp_flatten(v) for v in seg_part])

    # Threshold filtering.
    keep = scores.reshape(-1) >= OBJ_THRESH
    boxes, scores, seg_part = boxes[keep], scores[keep], seg_part[keep]

    # IMG_SIZE is (W, H); interpolation and masks use (H, W).
    out_h, out_w = IMG_SIZE[1], IMG_SIZE[0]
    if boxes.shape[0] == 0:
        # Nothing survived: skip NMS and interpolation on empty tensors.
        return boxes, scores, np.zeros((0, out_h, out_w), dtype=bool)

    # NMS; torchvision documents the scores argument as a 1-D tensor.
    ids = torchvision.ops.nms(
        torch.tensor(boxes, dtype=torch.float32),
        torch.tensor(scores.reshape(-1), dtype=torch.float32),
        NMS_THRESH,
    )
    ids = ids.tolist()[:MAX_DETECT]
    boxes, scores, seg_part = boxes[ids], scores[ids], seg_part[ids]

    # Mask decode: coefficients x prototypes -> per-detection mask.
    ph, pw = proto.shape[-2:]
    proto = proto.reshape(seg_part.shape[-1], -1)
    seg_img = sigmoid(np.matmul(seg_part, proto)).reshape(-1, ph, pw)
    seg_img = F.interpolate(
        torch.tensor(seg_img)[None],
        size=(out_h, out_w),
        mode='bilinear',
        align_corners=False,
    )[0]
    seg_img = _crop_mask(seg_img, torch.tensor(boxes))
    return boxes, scores, seg_img.numpy() > 0.5
# ---------------- Angle computation ----------------
def compute_angle(boxes, seg_img, h, w, filename, mode="show"):
    """Estimate the opening angle between the two largest segmented jaws.

    For every detection the mask is resized to the image, cropped to its box,
    and its largest contour is fitted with a minimum-area rectangle. The two
    contours with the largest area are taken as the jaws. The direction of
    each rectangle's long edge is oriented towards the midpoint between the
    two centroids, and the acute angle between the directions is returned.

    Args:
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes in mask pixel
            coordinates (the same resolution as ``seg_img``).
        seg_img: Array of shape ``(k, mask_h, mask_w)`` with one binary mask
            per box.
        h: Height of the original image in pixels.
        w: Width of the original image in pixels.
        filename: Base name used for the saved visualisation
            (``angle_<filename>`` inside ``OUTPUT_DIR``).
        mode: When ``"show"``, a visualisation image is written to disk.

    Returns:
        The opening angle in degrees rounded to two decimals, or ``None`` if
        fewer than two usable jaws were found.
    """
    composite_mask = np.zeros((h, w), dtype=np.uint8)
    jaws = []
    for box, mask in zip(boxes, seg_img):
        # Boxes live in mask coordinates; rescale them to the image before
        # using them as indices into the image-sized mask.
        mask_h, mask_w = mask.shape[:2]
        scale_x, scale_y = w / mask_w, h / mask_h
        x1 = max(0, int(np.floor(box[0] * scale_x)))
        y1 = max(0, int(np.floor(box[1] * scale_y)))
        x2 = min(w, int(np.ceil(box[2] * scale_x)))
        y2 = min(h, int(np.ceil(box[3] * scale_y)))

        obj_mask = np.zeros((h, w), dtype=np.uint8)
        mask_resized = cv2.resize(mask.astype(np.uint8), (w, h))
        obj_mask[y1:y2, x1:x2] = mask_resized[y1:y2, x1:x2] * 255

        contours, _ = cv2.findContours(
            obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        if len(contours) == 0:
            continue
        largest_contour = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(largest_contour)
        if area < 100:
            # Ignore tiny blobs.
            continue
        rect = cv2.minAreaRect(largest_contour)
        jaws.append({'contour': largest_contour, 'rect': rect, 'area': area})
        composite_mask = np.maximum(composite_mask, obj_mask)

    if len(jaws) < 2:
        print(f"❌ 检测到的夹具少于2个({len(jaws)}个)")
        return None

    jaws.sort(key=lambda x: x['area'], reverse=True)
    jaw1, jaw2 = jaws[0], jaws[1]

    def get_long_edge_vector(rect):
        # Unit vector along the longer side of a cv2.minAreaRect result.
        center, (w_, h_), angle = rect
        rad = np.radians(angle + (0 if w_ >= h_ else 90))
        return np.array([np.cos(rad), np.sin(rad)])

    def get_center(contour):
        # Centroid from image moments; (0, 0) for a degenerate contour.
        M = cv2.moments(contour)
        if M['m00'] == 0:
            return np.array([0, 0])
        return np.array([M['m10'] / M['m00'], M['m01'] / M['m00']])

    def to_point(vec):
        # OpenCV drawing calls are given plain Python ints.
        return tuple(int(v) for v in vec)

    dir1 = get_long_edge_vector(jaw1['rect'])
    dir2 = get_long_edge_vector(jaw2['rect'])
    center1, center2 = get_center(jaw1['contour']), get_center(jaw2['contour'])
    fixture_center = (center1 + center2) / 2
    to_center1, to_center2 = fixture_center - center1, fixture_center - center2
    # Flip each direction so it points towards the midpoint of the two jaws.
    if np.linalg.norm(to_center1) > 1e-6 and np.dot(dir1, to_center1) < 0:
        dir1 = -dir1
    if np.linalg.norm(to_center2) > 1e-6 and np.dot(dir2, to_center2) < 0:
        dir2 = -dir2

    cos_angle = np.clip(np.dot(dir1, dir2), -1.0, 1.0)
    angle = np.degrees(np.arccos(cos_angle))
    opening_angle = min(angle, 180 - angle)

    if mode == "show":
        vis_img = np.stack([composite_mask] * 3, axis=-1)
        vis_img[composite_mask > 0] = [255, 255, 255]
        box1 = np.int32(cv2.boxPoints(jaw1['rect']))
        box2 = np.int32(cv2.boxPoints(jaw2['rect']))
        cv2.drawContours(vis_img, [box1], 0, (0, 0, 255), 2)
        cv2.drawContours(vis_img, [box2], 0, (255, 0, 0), 2)
        scale = 60
        c1, c2 = to_point(center1), to_point(center2)
        end1 = to_point(center1 + scale * dir1)
        end2 = to_point(center2 + scale * dir2)
        cv2.arrowedLine(vis_img, c1, end1, (0, 255, 0), 2, tipLength=0.3)
        cv2.arrowedLine(vis_img, c2, end2, (0, 255, 0), 2, tipLength=0.3)
        # Hershey fonts only cover ASCII, so the degree sign is spelled out.
        cv2.putText(vis_img, f"Angle: {opening_angle:.2f} deg", (20, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        save_path = os.path.join(OUTPUT_DIR, f'angle_{filename}')
        # imwrite reports failure through its return value.
        if cv2.imwrite(save_path, vis_img):
            print(f"✅ 结果已保存: {save_path}")
        else:
            print(f"❌ 结果保存失败: {save_path}")

    return round(float(opening_angle), 2)
# ---------------- Main program ----------------
def main(model_path="/userdata/bushu/seg.rknn",
         img_path="/userdata/bushu/test.jpg"):
    """Run segmentation on one image and report the jaw opening angle.

    Args:
        model_path: Path of the RKNN model file. Defaults to the deployment
            location used on the target board.
        img_path: Path of the image to analyse.

    Returns:
        None. The angle is printed and, via ``compute_angle``, a
        visualisation image is saved.
    """
    # Validate the image before acquiring the model so that a bad path
    # does not leave a model handle unreleased.
    img_src = cv2.imread(img_path)
    if img_src is None:
        print("❌ 图片路径错误:", img_path)
        return
    h, w = img_src.shape[:2]
    # NOTE(review): cv2.imread yields BGR and no colour conversion is done
    # here - confirm the model was exported for BGR input.
    img = cv2.resize(img_src, IMG_SIZE)

    from py_utils.rknn_executor import RKNN_model_container
    model = RKNN_model_container(model_path, target='rk3588', device_id=None)
    try:
        outputs = model.run([img])
        boxes, scores, seg_img = post_process(outputs)
        filename = os.path.basename(img_path)
        angle = compute_angle(boxes, seg_img, h, w, filename, mode="show")
        if angle is not None:
            print(f"🎉 检测到的夹具开合角度: {angle}°")
    finally:
        # Always release the model, even if inference or post-processing fails.
        model.release()


if __name__ == "__main__":
    main()