Files
zjsh_yolov11/angle_base_seg/angle_test_1.py

109 lines
3.2 KiB
Python
Raw Normal View History

2025-09-01 14:14:18 +08:00
from ultralytics import YOLO
import cv2
import numpy as np
import os
import torch
import torch.nn.functional as F
# ------------------ Model and path configuration ------------------
# Weights file handed to YOLO() below; the path is relative, not absolute.
MODEL_PATH = '../ultralytics_yolo11-main/runs/train/seg/exp3/weights/best.pt'
# Directory that receives the images written by detect_jaw_angle_fast in 'show' mode.
OUTPUT_DIR = '../test_image'
# Created when the module is imported; exist_ok avoids an error if it is already there.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# The model is loaded once at import time and moved to the CUDA device,
# so importing this module requires the weights file and a usable GPU.
model = YOLO(MODEL_PATH)
model.to('cuda')
def _long_axis_angle(rect):
    """Return the angle of a cv2.minAreaRect result, adding 90 when size[0] < size[1].

    `rect` is the ((cx, cy), (size0, size1), angle) tuple produced by
    cv2.minAreaRect. The adjustment makes the returned angle independent of
    which of the two sides OpenCV happened to report first.
    """
    side_a, side_b = rect[1]
    angle = rect[2]
    return angle + 90 if side_a < side_b else angle


def detect_jaw_angle_fast(image_path: str, mode: str = 'silent'):
    """Estimate the angle between the two largest segmented regions in an image.

    The image is run through the module-level `model`; every predicted mask
    is resized to the image resolution, cropped to its bounding box, and
    reduced to the minimum-area rectangle of its convex hull. The two regions
    with the largest hull area are kept and the angle between their
    rectangles is returned, folded into the range [0, 90].

    Args:
        image_path: Path of the image to analyse.
        mode: 'show' additionally writes ``fast_<image name>.png`` into
            OUTPUT_DIR, containing both rectangles and the angle drawn on a
            black canvas. Any other value (default 'silent') writes nothing.

    Returns:
        The angle rounded to two decimals, or None when the model returns no
        masks or fewer than two regions survive the size filters.

    Raises:
        FileNotFoundError: if cv2.imread cannot read `image_path`.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"无法读取图像: {image_path}")
    img_h, img_w = img.shape[:2]

    # Inference (batching several images could speed this up further).
    results = model(image_path, imgsz=640, conf=0.5, device='cuda')
    r = results[0]
    if r.masks is None:
        return None

    # Upsample all masks in one call. bilinear interpolate needs a 4-D input,
    # so the 3-D mask stack is treated as one batch with N channels.
    # NOTE(review): the masks are resized straight to the original image size;
    # if the model input was letterbox-padded they may be slightly offset from
    # the boxes - confirm against the Ultralytics mask-scaling documentation.
    masks_tensor = r.masks.data
    boxes = r.boxes.xyxy.cpu().numpy()
    masks = F.interpolate(
        masks_tensor.unsqueeze(0).float(),
        size=(img_h, img_w),
        mode='bilinear',
        align_corners=False,
    )
    masks = (masks[0] > 0.5).cpu().numpy().astype(np.uint8)  # [N, img_h, img_w]

    jaws = []
    for mask, box in zip(masks, boxes):
        x1, y1, x2, y2 = map(int, box)
        # Drop tiny detections before doing any contour work.
        if (x2 - x1) * (y2 - y1) < 100:
            continue

        # Clip the box to the image and skip it if nothing is left to crop.
        x1c, y1c = max(0, x1), max(0, y1)
        x2c, y2c = min(img_w, x2), min(img_h, y2)
        if x2c <= x1c or y2c <= y1c:
            continue

        # Only the pixels inside the box are searched for mask points.
        coords = cv2.findNonZero(mask[y1c:y2c, x1c:x2c])
        if coords is None or len(coords) < 5:
            continue

        hull = cv2.convexHull(coords)
        area = cv2.contourArea(hull)
        if area < 100:
            continue

        # The hull lives in crop coordinates; move the rectangle centre back
        # into full-image coordinates so it can be drawn on a full-size
        # canvas. Size and angle are unaffected by the translation.
        (cx, cy), size, angle = cv2.minAreaRect(hull)
        rect = ((cx + x1c, cy + y1c), size, angle)
        jaws.append({'rect': rect, 'area': area})

    if len(jaws) < 2:
        return None

    # Keep the two regions with the largest hull area.
    jaws = sorted(jaws, key=lambda jaw: jaw['area'], reverse=True)[:2]
    rect1, rect2 = jaws[0]['rect'], jaws[1]['rect']

    angle1 = _long_axis_angle(rect1) % 180
    angle2 = _long_axis_angle(rect2) % 180
    # Both angles are in [0, 180); take the smaller of the two complementary
    # differences so the result lies in [0, 90].
    diff = abs(angle1 - angle2)
    opening_angle = min(diff, 180 - diff)

    # Optional visualisation.
    if mode == 'show':
        name_only = os.path.splitext(os.path.basename(image_path))[0]
        vis = np.zeros((img_h, img_w, 3), dtype=np.uint8)
        box1 = cv2.boxPoints(rect1)
        box2 = cv2.boxPoints(rect2)
        cv2.drawContours(vis, [np.int32(box1)], 0, (0, 0, 255), 2)
        cv2.drawContours(vis, [np.int32(box2)], 0, (255, 0, 0), 2)
        # Hershey fonts cannot render the degree sign (it is drawn as
        # question marks), so the overlay uses plain ASCII.
        cv2.putText(vis, f"{opening_angle:.1f} deg", (20, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
        cv2.imwrite(os.path.join(OUTPUT_DIR, f"fast_{name_only}.png"), vis)

    return round(opening_angle, 2)
# ------------------ Test ------------------
if __name__ == '__main__':
    import sys

    # An optional first command-line argument selects the image to process;
    # without it the original sample image path is used.
    image_path = sys.argv[1] if len(sys.argv) > 1 else '/home/hx/yolo/output_masks/2.jpg'
    print(f"🚀 处理: {image_path}")
    angle = detect_jaw_angle_fast(image_path, mode='show')
    # None means the function found no masks or fewer than two usable regions.
    if angle is not None:
        print(f"✅ 角度: {angle}°")
    else:
        print("❌ 未检测到两个夹具")