"""Run YOLO segmentation inference on fixed ROIs of each input image.

For every image, each ROI listed in the ROI coordinate file is cropped,
resized to ``TARGET_SIZE`` x ``TARGET_SIZE``, passed through the model, and
the predictions are mapped back into full-image coordinates. Outputs are a
visualisation image, optional YOLO-format polygon labels, and optional
per-ROI binary masks.
"""

import argparse
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from ultralytics import YOLO

# ====================== Configuration ======================
MODEL_PATH = "best.pt"
# SOURCE_IMG_DIR = "/home/hx/yolo/output_masks"  # alternative input image directory
SOURCE_IMG_DIR = "/home/hx/yolo/yemian/test_image"  # input image directory
OUTPUT_DIR = "/home/hx/yolo/output_masks2"  # root directory for inference outputs
ROI_COORDS_FILE = "./roi_coordinates/1_rois2.txt"  # must match the file used for training
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SAVE_TXT = True
SAVE_MASKS = True
VIEW_IMG = False
LINE_WIDTH = 2
TARGET_SIZE = 640  # must match the training input size

# File suffixes (lower-case) that are treated as images when scanning a folder.
_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp")


def load_roi_coords(txt_path):
    """Load ROI rectangles from a text file.

    Each non-empty line that does not start with ``#`` must contain four
    comma-separated integers ``x,y,w,h``. Lines that cannot be parsed are
    reported on stdout and skipped. Several ROIs may be listed.

    :param txt_path: path of the ROI coordinate file
    :return: list of ``(x, y, w, h)`` integer tuples, in file order
    :raises FileNotFoundError: if ``txt_path`` does not exist
    """
    if not os.path.exists(txt_path):
        raise FileNotFoundError(f"❌ ROI 文件未找到: {txt_path}")

    rois = []
    # utf-8-sig reads plain UTF-8 too, and strips a BOM that would otherwise
    # make the first coordinate of the first line unparsable.
    with open(txt_path, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            try:
                # Both a non-integer field and a wrong field count raise
                # ValueError, so that is the only exception to handle here.
                x, y, w, h = map(int, line.split(","))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行: '{line}' | 错误: {e}")
                continue
            rois.append((x, y, w, h))
            print(f"📌 加载 ROI: (x={x}, y={y}, w={w}, h={h})")
    return rois


def plot_result_with_opacity_and_roi(
    result,
    orig_img_shape,
    roi_box,
    line_width=2,
    mask_opacity=0.5,
):
    """Draw one ROI's segmentation result on a black full-image canvas.

    The predictions in ``result`` live in the cropped-and-resized
    (``TARGET_SIZE`` x ``TARGET_SIZE``) space; masks and boxes are scaled
    back to the ROI size and shifted by the ROI origin.

    :param result: YOLO inference result for the cropped, resized ROI image
    :param orig_img_shape: ``(H, W)`` of the original image
    :param roi_box: ``(x, y, w, h)`` of the ROI inside the original image
    :param line_width: thickness of the bounding-box outline
    :param mask_opacity: weight applied to the coloured mask overlay
    :return: BGR ``uint8`` image of shape ``(H, W, 3)`` holding only the
        drawn masks, boxes and labels (black elsewhere)
    """
    h_orig, w_orig = orig_img_shape
    x, y, w, h = roi_box
    img = np.zeros((h_orig, w_orig, 3), dtype=np.uint8)

    if result.masks is not None:
        # NOTE(review): the original author notes these as (N, 640, 640).
        masks = result.masks.data.cpu().numpy()
        overlay = img.copy()
        # A view into the overlay: writes below land in the ROI rectangle.
        roi_region = overlay[y:y + h, x:x + w]
        colors = np.random.randint(0, 255, size=(len(masks), 3), dtype=np.uint8)
        for color, mask in zip(colors, masks):
            # Bring the mask back to the ROI's pixel size, then binarise.
            roi_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            roi_region[roi_mask > 0.5] = color.tolist()
        # Blend the coloured overlay into the canvas in place.
        cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)

    if result.boxes is not None:
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        colors = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1

        for box, cls_id, conf, color_arr in zip(boxes, classes, confidences, colors):
            # Map from model-input space to ROI pixels, then to the full image.
            x1 = int(box[0] * w / TARGET_SIZE) + x
            y1 = int(box[1] * h / TARGET_SIZE) + y
            x2 = int(box[2] * w / TARGET_SIZE) + x
            y2 = int(box[3] * h / TARGET_SIZE) + y
            color = color_arr.tolist()

            cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)

            label = f"{cls_id} {conf:.2f}"
            (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
            # Filled background strip above the box so the label stays readable.
            cv2.rectangle(img, (x1, y1 - text_h - 6), (x1 + text_w, y1), color, -1)
            cv2.putText(
                img, label, (x1, y1 - 4), font, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA,
            )

    return img


def _collect_image_files(source):
    """Return the image paths to process for ``source`` (a ``Path``).

    A file is returned as-is; a directory is scanned (non-recursively) for
    files whose suffix, compared case-insensitively, is in
    ``_IMAGE_SUFFIXES``. The result is sorted so runs are reproducible.
    """
    if source.is_file():
        return [source]
    if not source.is_dir():
        return []
    return sorted(
        p for p in source.iterdir()
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def _segments_to_label_lines(result, roi_box, orig_img_shape):
    """Convert one ROI's mask polygons to YOLO label lines.

    Polygons are mapped from model-input space to full-image pixels and then
    normalised by the full image size. The arrays returned by
    ``result.masks.xy`` are not modified.
    """
    x, y, w, h = roi_box
    h_orig, w_orig = orig_img_shape
    scale_x = w / TARGET_SIZE
    scale_y = h / TARGET_SIZE

    lines = []
    for cls_value, seg in zip(result.boxes.cls, result.masks.xy):
        # A polygon needs at least three points; skipping degenerate ones
        # avoids writing a label line that holds only a class id.
        if len(seg) < 3:
            continue
        # Work on a fresh float64 array instead of editing ``seg`` in place.
        pts = np.asarray(seg, dtype=np.float64) * (scale_x, scale_y) + (x, y)
        pts /= (w_orig, h_orig)
        coords = " ".join(f"{val:.6f}" for val in pts.ravel())
        lines.append(f"{int(cls_value)} {coords}")
    return lines


def _save_roi_mask(result, roi_box, orig_img_shape, save_path):
    """Write the union of one ROI's masks as a full-image binary PNG.

    :return: the boolean returned by ``cv2.imwrite``
    """
    x, y, w, h = roi_box
    h_orig, w_orig = orig_img_shape

    mask_data = result.masks.data.cpu().numpy()
    combined_mask = (mask_data.sum(axis=0) > 0).astype(np.uint8) * 255
    # Resize back to the ROI size before pasting into the full-size canvas.
    roi_mask = cv2.resize(combined_mask, (w, h), interpolation=cv2.INTER_NEAREST)

    full_mask = np.zeros((h_orig, w_orig), dtype=np.uint8)
    full_mask[y:y + h, x:x + w] = roi_mask
    return cv2.imwrite(str(save_path), full_mask)


def run_segmentation_inference_with_roi(
    model_path,
    source,
    output_dir,
    roi_coords_file,
    conf_threshold=0.25,
    iou_threshold=0.45,
    device="cuda:0",
    save_txt=True,
    save_masks=True,
    view_img=False,
    line_width=2,
):
    """Run ROI-based segmentation inference over one image or a folder.

    :param model_path: path of the YOLO weights file
    :param source: a single image file or a directory of images
    :param output_dir: root output directory; ``labels``, ``masks`` and
        ``visualize`` sub-directories are created inside it
    :param roi_coords_file: ROI coordinate file (see ``load_roi_coords``)
    :param conf_threshold: confidence threshold passed to the model
    :param iou_threshold: IoU threshold passed to the model
    :param device: device string passed to the model
    :param save_txt: write one merged YOLO polygon label file per image
    :param save_masks: write one binary mask PNG per image and ROI
    :param view_img: show each result in a window; ESC stops the run
    :param line_width: bounding-box outline thickness
    :return: ``None``
    :raises FileNotFoundError: if ``roi_coords_file`` does not exist
    """
    print(f"🚀 加载模型: {model_path}")
    print(f"💻 使用设备: {device}")

    model = YOLO(model_path)

    rois = load_roi_coords(roi_coords_file)
    if not rois:
        print("❌ 没有加载到任何有效的 ROI,程序退出。")
        return

    # Output directories are created unconditionally.
    output_dir = Path(output_dir)
    txt_dir = output_dir / "labels"
    mask_dir = output_dir / "masks"
    vis_dir = output_dir / "visualize"
    for directory in (output_dir, txt_dir, mask_dir, vis_dir):
        directory.mkdir(parents=True, exist_ok=True)

    source = Path(source)
    img_files = _collect_image_files(source)
    if not img_files:
        print(f"❌ 在 {source} 中未找到图像文件")
        return

    print(f"🖼️ 共 {len(img_files)} 张图片待推理...")

    for img_path in img_files:
        print(f"\n🔍 正在处理: {img_path.name}")
        orig_img = cv2.imread(str(img_path))
        if orig_img is None:
            print(f"❌ 无法读取图像: {img_path}")
            continue

        h_orig, w_orig = orig_img.shape[:2]
        full_vis_img = orig_img.copy()  # accumulates the drawings of every ROI
        all_txt_lines = []

        for idx, (x, y, w, h) in enumerate(rois):
            # A non-positive size would yield an empty crop or, for negative
            # values, a wrapped slice followed by a cv2.resize failure.
            if w <= 0 or h <= 0:
                print(f"⚠️ ROI 尺寸无效,跳过: ({x},{y},{w},{h})")
                continue
            if x < 0 or y < 0 or x + w > w_orig or y + h > h_orig:
                print(f"⚠️ ROI 越界,跳过: ({x},{y},{w},{h})")
                continue

            # Crop the ROI and resize it to the model input size.
            roi_img = orig_img[y:y + h, x:x + w]
            resized_img = cv2.resize(
                roi_img, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_LINEAR
            )

            results = model(
                source=resized_img,
                conf=conf_threshold,
                iou=iou_threshold,
                imgsz=TARGET_SIZE,
                device=device,
                verbose=False,
            )
            result = results[0]

            # Visualisation: draw this ROI's result in full-image coordinates
            # and add it on top of the running visualisation.
            vis_part = plot_result_with_opacity_and_roi(
                result=result,
                orig_img_shape=(h_orig, w_orig),
                roi_box=(x, y, w, h),
                line_width=line_width,
                mask_opacity=0.5,
            )
            full_vis_img = cv2.addWeighted(full_vis_img, 1.0, vis_part, 0.7, 0)

            if result.masks is None:
                continue

            if save_txt:
                all_txt_lines.extend(
                    _segments_to_label_lines(result, (x, y, w, h), (h_orig, w_orig))
                )

            if save_masks:
                mask_save_path = mask_dir / f"mask_{img_path.stem}_roi{idx}.png"
                if not _save_roi_mask(
                    result, (x, y, w, h), (h_orig, w_orig), mask_save_path
                ):
                    print(f"❌ 无法保存掩码: {mask_save_path}")

        # cv2.imwrite reports failure through its return value, not an
        # exception, so only announce success when the write succeeded.
        vis_save_path = vis_dir / f"vis_{img_path.name}"
        if cv2.imwrite(str(vis_save_path), full_vis_img):
            print(f"✅ 保存可视化结果: {vis_save_path}")
        else:
            print(f"❌ 无法保存可视化结果: {vis_save_path}")

        if save_txt and all_txt_lines:
            txt_save_path = txt_dir / f"{img_path.stem}.txt"
            with open(txt_save_path, "w", encoding="utf-8") as f:
                f.write("\n".join(all_txt_lines) + "\n")
            print(f"📝 保存标签: {txt_save_path}")

        if view_img:
            display_img = cv2.resize(full_vis_img, (960, 540))
            cv2.imshow("Segmentation Result", display_img)
            # Mask to the low byte: some backends set higher bits in the
            # key code returned by waitKey.
            if cv2.waitKey(0) & 0xFF == 27:  # ESC
                cv2.destroyAllWindows()
                break

    if view_img:
        cv2.destroyAllWindows()

    print(f"\n🎉 推理完成!结果保存在: {output_dir}")


# ====================== Entry point ======================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
    parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
    parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
    parser.add_argument("--roi-file", default=ROI_COORDS_FILE, help="ROI 坐标文件路径")
    parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
    parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
    parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
    parser.add_argument("--view-img", action="store_true", default=VIEW_IMG, help="显示图像")
    # The module-level constants supply the defaults; the --no-* switches make
    # it possible to turn an option off when its constant is True.
    parser.add_argument("--save-txt", dest="save_txt", action="store_true",
                        default=SAVE_TXT, help="保存标签")
    parser.add_argument("--no-save-txt", dest="save_txt", action="store_false",
                        help="不保存标签")
    parser.add_argument("--save-masks", dest="save_masks", action="store_true",
                        default=SAVE_MASKS, help="保存掩码")
    parser.add_argument("--no-save-masks", dest="save_masks", action="store_false",
                        help="不保存掩码")

    opt = parser.parse_args()

    run_segmentation_inference_with_roi(
        model_path=opt.model,
        source=opt.source,
        output_dir=opt.output,
        roi_coords_file=opt.roi_file,
        conf_threshold=opt.conf,
        iou_threshold=opt.iou,
        device=opt.device,
        save_txt=opt.save_txt,
        save_masks=opt.save_masks,
        view_img=opt.view_img,
        line_width=LINE_WIDTH,
    )