Files
zjsh_yolov11/angle_base_seg/test_seg_angle_f.py
2025-09-05 14:29:33 +08:00

210 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

import cv2
import numpy as np
import torch
from ultralytics import YOLO
# ------------------ Configuration ------------------
# Weights file of the trained segmentation model.
model_path = '../ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
# Alternative input folder (disabled):
# img_folder = '/home/hx/yolo/ultralytics_yolo11-main/dataset1/test'
img_folder = '/home/hx/yolo/output_masks'
# Folder that receives the generated mask images; created on import if missing.
output_mask_dir = 'output_masks1'
os.makedirs(output_mask_dir, exist_ok=True)
# Lower-case file extensions accepted as input images.
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
# ------------------ Load model ------------------
model = YOLO(model_path)
# Use the GPU only when one is actually available; moving the model to
# 'cuda' unconditionally raises on CPU-only machines.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
def get_orientation_vector(contour):
    """
    Estimate the main direction of a contour with cv2.fitLine.

    Contours with fewer than five points are not fitted; the unit vector
    along the x axis is returned for them instead.

    Returns: a 2-component direction vector, scaled to unit length unless
    its norm is (almost) zero, in which case it is returned unscaled.
    """
    if len(contour) < 5:
        # Fallback direction: along the x axis.
        return np.array([1.0, 0.0])

    vx, vy, _x0, _y0 = cv2.fitLine(contour, cv2.DIST_L2, 0, 0.01, 0.01)
    axis = np.array([vx[0], vy[0]])
    length = np.linalg.norm(axis)
    if length > 1e-8:
        return axis / length
    return axis
def get_contour_center(contour):
    """
    Return the centroid of a contour as an integer (x, y) array.

    The centroid is derived from the image moments reported by cv2.moments.
    When the zeroth moment is 0 the origin [0, 0] is returned instead.
    """
    moments = cv2.moments(contour)
    zeroth = moments["m00"]
    if zeroth == 0:
        return np.array([0, 0])
    cx = int(moments["m10"] / zeroth)
    cy = int(moments["m01"] / zeroth)
    return np.array([cx, cy])
def calculate_jaw_opening_angle(jaw1, jaw2):
    """
    Compute the opening angle between two jaws, in degrees.

    Each jaw is a dict whose 'contour' entry is handed to
    get_contour_center and get_orientation_vector. Both direction vectors
    are oriented so that they point from their own centroid toward the
    midpoint of the two centroids, then the angle between them is measured.

    If the angle exceeds 170 degrees the computation is repeated with
    jaw2's direction negated; if the result is still above 170 degrees,
    its supplement (180 - angle) is returned.

    Returns: (angle, dir1_final, dir2_final)

    NOTE(review): the retry negates jaw2's direction before the orientation
    step, which flips it back toward the midpoint, so the retry normally
    reproduces the first result and the supplement branch is what takes
    effect. Confirm whether that is intended.
    """
    centroid_a = get_contour_center(jaw1['contour'])
    centroid_b = get_contour_center(jaw2['contour'])
    fixture_center = np.array([
        (centroid_a[0] + centroid_b[0]) / 2.0,
        (centroid_a[1] + centroid_b[1]) / 2.0,
    ])

    # Main directions from the line fit.
    raw_dir_a = get_orientation_vector(jaw1['contour'])
    raw_dir_b = get_orientation_vector(jaw2['contour'])

    def face_center(direction, centroid):
        """Return direction, negated if it points away from fixture_center."""
        toward = fixture_center - centroid
        if np.linalg.norm(toward) > 1e-6:
            toward = toward / np.linalg.norm(toward)
            if np.dot(direction, toward) < 0:
                return -direction
        return direction

    def oriented_angle(dir_a, dir_b):
        """Orient both directions toward the centre and measure their angle."""
        dir_a = face_center(dir_a, centroid_a)
        dir_b = face_center(dir_b, centroid_b)
        cosine = np.clip(np.dot(dir_a, dir_b), -1.0, 1.0)
        return np.degrees(np.arccos(cosine)), dir_a, dir_b

    # First attempt: directions as fitted.
    first_angle, first_a, first_b = oriented_angle(raw_dir_a, raw_dir_b)
    if first_angle <= 170.0:
        return first_angle, first_a, first_b

    print(f"⚠️ 初始角度过大: {first_angle:.2f}°,尝试翻转 jaw2 方向...")
    retry_angle, retry_a, retry_b = oriented_angle(raw_dir_a, -raw_dir_b)
    print(f"🔄 方向修正后: {retry_angle:.2f}°")

    if retry_angle > 170.0:
        # Numeric fallback: still too large, so report the supplement.
        supplement = 180.0 - retry_angle
        print(f"🔧 数值修正: {retry_angle:.2f}° → {supplement:.2f}°")
        return supplement, retry_a, retry_b
    return retry_angle, retry_a, retry_b
def _to_point(xy):
    """Convert a 2-element array-like to a tuple of built-in ints for OpenCV drawing."""
    return int(xy[0]), int(xy[1])


def _extract_jaw_candidates(results, height, width):
    """
    Turn segmentation results into contour records plus one merged mask.

    For every predicted mask the mask is resized to (width, height),
    restricted to its clipped bounding box, and its largest external
    contour is kept when that contour's area is at least 100.

    Returns: (candidates, composite_mask) where each candidate is a dict
    with 'rect' (cv2.minAreaRect result), 'contour' and 'area', and
    composite_mask is the per-pixel maximum of all kept object masks.

    NOTE(review): the mask is resized straight to the image size; this
    assumes the mask tensor covers the whole frame without padding -
    confirm against the Ultralytics mask documentation.
    """
    candidates = []
    composite_mask = np.zeros((height, width), dtype=np.uint8)
    for r in results:
        if r.masks is None:
            continue
        masks = r.masks.data.cpu().numpy()
        boxes = r.boxes.xyxy.cpu().numpy()
        for mask, box in zip(masks, boxes):
            # Clip the box to the image so the slices below stay in range.
            x1, y1, x2, y2 = map(int, box)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(width, x2), min(height, y2)

            obj_mask = np.zeros((height, width), dtype=np.uint8)
            mask_resized = cv2.resize(mask, (width, height))
            obj_mask[y1:y2, x1:x2] = (mask_resized[y1:y2, x1:x2] * 255).astype(np.uint8)

            contours, _ = cv2.findContours(obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            if not contours:
                continue
            largest_contour = max(contours, key=cv2.contourArea)
            area = cv2.contourArea(largest_contour)
            if area < 100:
                continue

            candidates.append({
                'rect': cv2.minAreaRect(largest_contour),
                'contour': largest_contour,
                'area': area,
            })
            composite_mask = np.maximum(composite_mask, obj_mask)
    return candidates, composite_mask


def process_image(img_path, output_dir):
    """
    Segment one image, measure the jaw opening angle and save a mask image.

    The image is run through the module-level model. When fewer than two
    objects survive filtering, only the merged single-channel mask is
    written as 'mask_<name>.png'. Otherwise the two largest objects are
    treated as the jaws, their opening angle is computed with
    calculate_jaw_opening_angle, and an annotated three-channel mask is
    written as 'mask_with_angle_<name>.png'.

    Parameters:
        img_path: path of the image to process.
        output_dir: directory that receives the output image.

    Returns: None. Progress and failures are reported with print().
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return
    h, w = img.shape[:2]
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]
    print(f"\n🔄 正在处理: {filename}")

    results = model(img_path, imgsz=1280, conf=0.5)
    rotated_rects, composite_mask = _extract_jaw_candidates(results, h, w)

    if len(rotated_rects) < 2:
        print(f"⚠️ 检测到的对象少于2个{len(rotated_rects)}个): {filename}")
        mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
        # cv2.imwrite signals failure through its return value, not an exception.
        if cv2.imwrite(mask_save_path, composite_mask):
            print(f"✅ 掩码已保存(无足够夹具): {mask_save_path}")
        else:
            print(f"❌ 保存失败: {mask_save_path}")
        return

    # Three-channel visualisation: white foreground on black background.
    vis_mask = np.stack([composite_mask] * 3, axis=-1)
    vis_mask[composite_mask > 0] = [255, 255, 255]

    # The two largest objects are taken as the jaws.
    rotated_rects.sort(key=lambda x: x['area'], reverse=True)
    jaw1, jaw2 = rotated_rects[0], rotated_rects[1]

    opening_angle, dir1_final, dir2_final = calculate_jaw_opening_angle(jaw1, jaw2)
    print(f"✅ 最终夹具开合角度: {opening_angle:.2f}°")

    # ------------------ Visualisation ------------------
    center1 = get_contour_center(jaw1['contour'])
    center2 = get_contour_center(jaw2['contour'])
    fixture_center = _to_point(
        ((center1[0] + center2[0]) // 2, (center1[1] + center2[1]) // 2)
    )

    # Minimum-area rectangles: jaw1 in red, jaw2 in blue (BGR colours).
    box1 = np.int32(cv2.boxPoints(jaw1['rect']))
    cv2.drawContours(vis_mask, [box1], 0, (0, 0, 255), 2)
    box2 = np.int32(cv2.boxPoints(jaw2['rect']))
    cv2.drawContours(vis_mask, [box2], 0, (255, 0, 0), 2)

    # Main-direction arrows (green). Points are passed as built-in ints.
    scale = 60
    end1 = center1 + scale * dir1_final
    end2 = center2 + scale * dir2_final
    cv2.arrowedLine(vis_mask, _to_point(center1), _to_point(end1), (0, 255, 0), 2, tipLength=0.3)
    cv2.arrowedLine(vis_mask, _to_point(center2), _to_point(end2), (0, 255, 0), 2, tipLength=0.3)

    # Fixture centre marker (cyan).
    cv2.circle(vis_mask, fixture_center, 5, (255, 255, 0), -1)
    cv2.putText(vis_mask, "Center", (fixture_center[0] + 10, fixture_center[1]),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1)

    # Angle label. Hershey fonts cannot draw the non-ASCII degree sign (it
    # comes out as question marks), so the unit is spelled out.
    cv2.putText(vis_mask, f"Angle: {opening_angle:.2f} deg", (20, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

    vis_save_path = os.path.join(output_dir, f'mask_with_angle_{name_only}.png')
    if cv2.imwrite(vis_save_path, vis_mask):
        print(f"✅ 带角度标注的掩码图已保存: {vis_save_path}")
    else:
        print(f"❌ 保存失败: {vis_save_path}")
# ------------------ Main program ------------------
if __name__ == '__main__':
    # Batch driver: process every supported image found directly in img_folder.
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        # exit() is a helper injected by the site module for interactive
        # use; raise SystemExit directly and report the error via status 1.
        raise SystemExit(1)

    # Sorted so the processing order does not depend on directory listing order.
    image_files = sorted(
        f for f in os.listdir(img_folder)
        if f.lower().endswith(SUPPORTED_FORMATS)
    )
    if not image_files:
        print("⚠️ 未找到支持的图像文件")
        raise SystemExit(0)

    print(f"✅ 发现 {len(image_files)} 张图像,开始处理...")
    for image_file in image_files:
        image_path = os.path.join(img_folder, image_file)
        process_image(image_path, output_mask_dir)
    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")