import os

import cv2
import numpy as np
from ultralytics import YOLO

# ------------------ Configuration ------------------
model_path = 'ultralytics_yolo11-main/runs/train/exp4/weights/best.pt'
img_folder = '/home/hx/yolo/ultralytics_yolo11-main/dataset1/test'  # folder containing the input images
output_mask_dir = 'output_masks1'
os.makedirs(output_mask_dir, exist_ok=True)

# Supported image file extensions (compared case-insensitively).
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

# ------------------ Load model ------------------
model = YOLO(model_path)
model.to('cuda')  # run inference on the GPU


def _polygon_edges(points, min_length):
    """Return the edges of a closed polygon that are at least ``min_length`` long.

    ``points`` is a sequence of (x, y) vertices. Each returned edge is a dict
    with the keys ``start``, ``end``, ``length``, ``dx``, ``dy``, ``slope`` and
    ``vertical_score`` (``|dx| / |dy|``; smaller means closer to vertical).
    """
    vertices = [(int(x), int(y)) for x, y in points]
    count = len(vertices)
    if count < 2:
        return []
    # A two-point "polygon" has a single edge; do not emit it twice.
    edge_count = count if count > 2 else 1

    edges = []
    for idx in range(edge_count):
        start = vertices[idx]
        end = vertices[(idx + 1) % count]
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        length = float(np.hypot(dx, dy))
        if length < min_length:
            continue  # ignore edges that are too short

        edges.append({
            'start': start,
            'end': end,
            'length': length,
            'dx': dx,
            'dy': dy,
            # A vertical edge has an infinite slope; a horizontal edge has an
            # infinite vertical score.
            'slope': float('inf') if dx == 0 else dy / dx,
            'vertical_score': float('inf') if dy == 0 else abs(dx) / abs(dy),
        })
    return edges


def process_image(img_path, output_dir, min_edge_length=20):
    """Process one image: find the 4 longest mask edges and the 2 most vertical.

    The image is run through the module-level YOLO ``model``; all predicted
    instance masks are merged into one composite mask, which is saved as
    ``mask_<name>.png`` in ``output_dir``. The external contours of that mask
    are simplified to polygons, the four longest polygon edges are selected,
    and the two of those closest to vertical are drawn on a copy of the image
    that is saved as ``vertical_<name>.jpg``.

    Args:
        img_path: Path of the image to process.
        output_dir: Directory that receives the mask and visualisation files.
        min_edge_length: Polygon edges shorter than this (in pixels) are
            ignored.

    Returns:
        None. Progress and problems are reported with ``print``.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"❌ 无法读取图像: {img_path}")
        return

    h, w = img.shape[:2]
    filename = os.path.basename(img_path)
    name_only = os.path.splitext(filename)[0]

    print(f"\n🔄 正在处理: {filename}")

    # ------------------ Build the composite mask ------------------
    composite_mask = np.zeros((h, w), dtype=np.uint8)
    results = model(img_path, imgsz=1280, conf=0.5)

    for r in results:
        if r.masks is None:
            continue
        for mask in r.masks.data.cpu().numpy():
            # Masks come back at inference resolution; scale to the image size.
            mask_resized = cv2.resize(mask, (w, h))
            mask_img = (mask_resized * 255).astype(np.uint8)
            composite_mask = np.maximum(composite_mask, mask_img)

    # Save the mask; imwrite reports failure through its return value.
    mask_save_path = os.path.join(output_dir, f'mask_{name_only}.png')
    if cv2.imwrite(mask_save_path, composite_mask):
        print(f"✅ 掩码已保存: {mask_save_path}")
    else:
        print(f"❌ 掩码保存失败: {mask_save_path}")

    # ------------------ Extract contours ------------------
    contours, _ = cv2.findContours(composite_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        print(f"⚠️ 未检测到轮廓: {filename}")
        return

    # ------------------ Extract edge segments ------------------
    # Each contour is simplified on its own: stacking disjoint contours into a
    # single point list would invent edges that jump from one region to the next.
    line_segments = []
    for contour in contours:
        epsilon = 0.005 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        line_segments.extend(_polygon_edges(approx.reshape(-1, 2), min_edge_length))

    # Sort by length and keep the 4 longest edges.
    line_segments.sort(key=lambda seg: seg['length'], reverse=True)
    top_4_segments = line_segments[:4]

    if len(top_4_segments) < 2:
        print(f"⚠️ 不足2条有效边(共{len(top_4_segments)}条): {filename}")
        return

    # ------------------ Pick the 2 edges closest to vertical ------------------
    # sorted() leaves top_4_segments in length order so the report below labels
    # the edges correctly.
    vertical_candidates = sorted(top_4_segments, key=lambda seg: seg['vertical_score'])[:2]

    # ------------------ Report ------------------
    print("✅ 提取最长的4条边:")
    for i, seg in enumerate(top_4_segments):
        print(f"  第{i + 1}长边: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")

    print("✅ 最接近垂直的两条边(垂直评分最小):")
    for i, seg in enumerate(vertical_candidates):
        print(f"  垂直边{i + 1}: 长度={seg['length']:.1f}, 斜率={seg['slope']:.3f}, 垂直评分={seg['vertical_score']:.3f}")

    # ------------------ Visualisation ------------------
    vis_img = img.copy()
    colors = [(0, 0, 255), (255, 0, 0)]  # red, blue (BGR)

    for i, (seg, color) in enumerate(zip(vertical_candidates, colors)):
        (x1, y1), (x2, y2) = seg['start'], seg['end']
        cv2.line(vis_img, (x1, y1), (x2, y2), color, 3, cv2.LINE_AA)
        mid_x = (x1 + x2) // 2
        mid_y = (y1 + y2) // 2
        label = f"{seg['length']:.0f}, k={seg['slope']:.1f}"
        # Offset the second label so the two captions do not overlap.
        cv2.putText(vis_img, label, (mid_x, mid_y - 10 + i * 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    # Legend
    cv2.putText(vis_img, "Red: 1st closest to vertical", (20, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
    cv2.putText(vis_img, "Blue: 2nd closest to vertical", (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 1)

    # Save the visualisation.
    vis_save_path = os.path.join(output_dir, f'vertical_{name_only}.jpg')
    if cv2.imwrite(vis_save_path, vis_img):
        print(f"✅ 可视化图像已保存: {vis_save_path}")
    else:
        print(f"❌ 可视化图像保存失败: {vis_save_path}")


def main():
    """Process every supported image in ``img_folder``; return an exit status."""
    if not os.path.isdir(img_folder):
        print(f"❌ 图像文件夹不存在: {img_folder}")
        return 1

    # Sorted so that runs are reproducible; os.listdir order is arbitrary.
    image_files = sorted(
        f for f in os.listdir(img_folder)
        if f.lower().endswith(SUPPORTED_FORMATS)
    )

    if not image_files:
        print(f"⚠️ 在 {img_folder} 中未找到支持的图像文件")
        return 0

    print(f"✅ 发现 {len(image_files)} 张图像,开始批量处理...")

    for image_file in image_files:
        image_path = os.path.join(img_folder, image_file)
        process_image(image_path, output_mask_dir)

    print(f"\n🎉 所有图像处理完成!结果保存在: {output_mask_dir}")
    return 0


# ------------------ Entry point: walk the image folder ------------------
if __name__ == '__main__':
    raise SystemExit(main())