import os
import cv2
import torch
import argparse
import numpy as np
from ultralytics import YOLO
from pathlib import Path

# ====================== 配置参数 ======================
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg/exp7/weights/best.pt"
SOURCE_IMG_DIR = "/home/hx/yolo/output_masks"
OUTPUT_DIR = "/home/hx/yolo/output_masks2"
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): SAVE_TXT / SAVE_MASKS / VIEW_IMG are shadowed by the argparse
# store_true flags below (which default to False) — confirm intended defaults.
SAVE_TXT = True
SAVE_MASKS = True
VIEW_IMG = False
LINE_WIDTH = 2


def plot_result_with_opacity(result, line_width=2, mask_opacity=0.5):
    """Manually render a YOLO segmentation result with translucent masks.

    Masks are returned by the model at its inference resolution, which in
    general differs from the input image, so each mask is resized (nearest
    neighbor, then re-thresholded) to the original image size before blending.

    Args:
        result: a single ultralytics ``Results`` object.
        line_width: bounding-box outline thickness in pixels.
        mask_opacity: blend weight of the mask overlay, in [0, 1].

    Returns:
        BGR ``numpy.ndarray`` with masks, boxes and labels drawn.
    """
    img = result.orig_img.copy()  # HWC, BGR
    h, w = img.shape[:2]

    num_det = 0 if result.boxes is None else len(result.boxes)
    # FIX: one shared random color per detection, so each object's mask
    # overlay and box outline match. The original code generated two
    # independent random palettes, so mask and box colors never agreed.
    colors = np.random.randint(0, 255, size=(num_det, 3), dtype=np.uint8)

    if result.masks is not None and num_det > 0:
        masks = result.masks.data.cpu().numpy()  # (N, H_mask, W_mask)
        overlay = img.copy()
        for i in range(num_det):
            # INTER_NEAREST keeps values near-binary; > 0.5 re-binarizes
            # whatever interpolation leaves behind.
            resized = cv2.resize(masks[i], (w, h), interpolation=cv2.INTER_NEAREST)
            overlay[resized > 0.5] = colors[i].tolist()
        # Translucent blend of the colored overlay onto the image, in place.
        cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)

    if num_det > 0:
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1
        for i in range(len(boxes)):
            box = boxes[i].astype(int)
            color = colors[i].tolist()
            cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, line_width)

            label = f"{classes[i]} {confidences[i]:.2f}"
            (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
            text_h += baseline
            # Filled rectangle as label background, then white text on top.
            cv2.rectangle(img, (box[0], box[1] - text_h - 6),
                          (box[0] + text_w, box[1]), color, -1)
            cv2.putText(img, label, (box[0], box[1] - 4), font, font_scale,
                        (255, 255, 255), thickness, cv2.LINE_AA)

    return img


def run_segmentation_inference(
    model_path,
    source,
    output_dir,
    conf_threshold=0.25,
    iou_threshold=0.45,
    device="cuda:0",
    save_txt=True,
    save_masks=True,
    view_img=False,
    line_width=2,
):
    """Run YOLO segmentation over one image or a folder and save outputs.

    Saves, per image: a visualization (``seg_<name>``); optionally YOLO-format
    polygon labels (``labels/<stem>.txt``); optionally a combined binary mask
    PNG (``masks/mask_<stem>.png``) resized to the original image size.

    Args:
        model_path: path to the segmentation weights (.pt).
        source: image file or directory of jpg/jpeg/png/bmp images.
        output_dir: directory that receives all outputs (created if missing).
        conf_threshold / iou_threshold: detection thresholds.
        device: torch device string, e.g. "cuda:0" or "cpu".
        save_txt: write normalized-polygon label files.
        save_masks: write combined binary mask images.
        view_img: show each result window; ESC stops the loop.
        line_width: box outline thickness passed to the plotter.
    """
    print(f"🚀 加载模型: {model_path}")
    print(f"💻 使用设备: {device}")

    model = YOLO(model_path)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    txt_dir = output_dir / "labels"
    mask_dir = output_dir / "masks"
    if save_txt:
        txt_dir.mkdir(exist_ok=True)
    if save_masks:
        mask_dir.mkdir(exist_ok=True)

    # Collect image files (single file, or glob a directory).
    source = Path(source)
    if source.is_file():
        img_files = [source]
    else:
        img_files = (list(source.glob("*.jpg")) +
                     list(source.glob("*.jpeg")) +
                     list(source.glob("*.png")) +
                     list(source.glob("*.bmp")))

    if not img_files:
        print(f"❌ 在 {source} 中未找到图像文件")
        return

    print(f"🖼️ 共 {len(img_files)} 张图片待推理...")

    for img_path in img_files:
        print(f"🔍 推理: {img_path.name}")

        results = model(
            source=str(img_path),
            conf=conf_threshold,
            iou=iou_threshold,
            imgsz=640,
            device=device,
            verbose=True
        )
        result = results[0]
        orig_img = result.orig_img

        # Custom plotting with translucent mask overlay.
        plotted = plot_result_with_opacity(result, line_width=line_width, mask_opacity=0.5)

        save_path = output_dir / f"seg_{img_path.name}"
        cv2.imwrite(str(save_path), plotted)
        print(f"✅ 保存结果: {save_path}")

        # YOLO-format polygon labels: "<cls> x1 y1 x2 y2 ..." normalized to [0, 1].
        if save_txt and result.masks is not None:
            txt_path = txt_dir / (img_path.stem + ".txt")
            img_h, img_w = orig_img.shape[:2]
            with open(txt_path, 'w') as f:
                for i in range(len(result.boxes)):
                    cls_id = int(result.boxes.cls[i])
                    seg = result.masks.xy[i]  # polygon points, (P, 2) in pixels
                    # FIX: normalize x by width and y by height on the (P, 2)
                    # array BEFORE flattening. The original divided the
                    # flattened (2P,) array by a 2-element list, which raises a
                    # broadcasting ValueError for any polygon with P > 1.
                    seg = (np.asarray(seg) / [img_w, img_h]).flatten().tolist()
                    line = f"{cls_id} {' '.join(f'{x:.6f}' for x in seg)}\n"
                    f.write(line)
            print(f"📝 保存标签: {txt_path}")

        # Combined binary mask of all detections.
        if save_masks and result.masks is not None:
            mask = result.masks.data.cpu().numpy()
            combined_mask = (mask.sum(axis=0) > 0).astype(np.uint8) * 255
            # FIX: masks come back at the model's inference resolution; resize
            # so the saved PNG lines up with the original image (consistent
            # with the overlay rendering above).
            combined_mask = cv2.resize(combined_mask,
                                       (orig_img.shape[1], orig_img.shape[0]),
                                       interpolation=cv2.INTER_NEAREST)
            mask_save_path = mask_dir / f"mask_{img_path.stem}.png"
            cv2.imwrite(str(mask_save_path), combined_mask)
            print(f"🎨 保存掩码: {mask_save_path}")

        # Optional interactive preview; ESC stops the whole loop.
        if view_img:
            cv2.imshow("Segmentation Result", plotted)
            if cv2.waitKey(0) == 27:
                cv2.destroyAllWindows()
                break

    if view_img:
        cv2.destroyAllWindows()

    print(f"\n🎉 推理完成!结果保存在: {output_dir}")


# ====================== 主程序 ======================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
    parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
    parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
    parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
    parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
    parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
    parser.add_argument("--view-img", action="store_true", help="显示图像")
    parser.add_argument("--save-txt", action="store_true", help="保存标签")
    parser.add_argument("--save-masks", action="store_true", help="保存掩码")
    opt = parser.parse_args()

    run_segmentation_inference(
        model_path=opt.model,
        source=opt.source,
        output_dir=opt.output,
        conf_threshold=opt.conf,
        iou_threshold=opt.iou,
        device=opt.device,
        save_txt=opt.save_txt,
        save_masks=opt.save_masks,
        view_img=opt.view_img,
        line_width=LINE_WIDTH,
    )