#!/usr/bin/env python3
"""Run YOLO instance segmentation over a folder of images.

Every image in SOURCE_IMG_DIR is letterboxed to IMG_SIZE x IMG_SIZE and fed
to the model. Results are mapped back to the original image resolution and
written under OUTPUT_DIR:

* ``seg_<name>``            -- instance masks alpha-blended onto the original
                               image, plus bounding boxes
* ``labels/<stem>.txt``     -- polygon labels ("cls x1 y1 x2 y2 ..."),
                               normalized to the original image size
                               (when SAVE_TXT)
* ``masks/mask_<stem>.png`` -- union of all instance masks as a 0/255
                               single-channel image (when SAVE_MASKS)
"""
from pathlib import Path

import cv2
import numpy as np
import torch
from ultralytics import YOLO

# ====================== Configuration ======================
MODEL_PATH = "best.pt"
SOURCE_IMG_DIR = "/home/hx/yolo/yemian/test_image"
OUTPUT_DIR = "/home/hx/yolo/output_masks2"
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SAVE_TXT = True
SAVE_MASKS = True
VIEW_IMG = False
LINE_WIDTH = 2  # Bounding-box thickness in the visualization; 0 disables boxes.
IMG_SIZE = 640  # YOLO input size (side of the square letterbox canvas).
# Suffixes picked up from SOURCE_IMG_DIR, compared case-insensitively.
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png"}


# ====================== Letterbox resizing ======================
def letterbox_image(img, new_size=IMG_SIZE):
    """Fit *img* into a ``new_size`` x ``new_size`` canvas, keeping aspect ratio.

    The image is scaled by ``min(new_size / w, new_size / h)`` and centered on
    a 3-channel uint8 canvas filled with the value 114.

    Args:
        img: 3-channel image as returned by ``cv2.imread`` (a single-channel
            array would not fit the 3-channel canvas).
        new_size: side length of the square output canvas.

    Returns:
        ``(canvas, scale, pad_left, pad_top, new_w, new_h)``. ``new_w`` x
        ``new_h`` is the size of the resized image inside the canvas, and
        ``(pad_left, pad_top)`` is its top-left offset.
    """
    h, w = img.shape[:2]
    scale = min(new_size / w, new_size / h)
    # Clamp to at least 1 px: a very thin image would otherwise yield a 0-sized
    # target and make cv2.resize raise.
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    canvas = np.full((new_size, new_size, 3), 114, dtype=np.uint8)
    pad_left = (new_size - new_w) // 2
    pad_top = (new_size - new_h) // 2
    canvas[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized
    return canvas, scale, pad_left, pad_top, new_w, new_h


# ====================== Canvas -> original-image mapping ======================
def _canvas_to_orig(xs, ys, orig_hw, pad_left, pad_top, new_w, new_h):
    """Map letterbox-canvas pixel coordinates to original-image pixel coordinates."""
    H_ori, W_ori = orig_hw
    return (xs - pad_left) * (W_ori / new_w), (ys - pad_top) * (H_ori / new_h)


def _unletterbox_mask(mask, orig_hw, pad_left, pad_top, new_w, new_h):
    """Crop the letterbox padding off *mask* and resize it to the original image.

    Assumes *mask* is laid out on the IMG_SIZE x IMG_SIZE letterbox canvas,
    i.e. one entry of ``result.masks.data``. NOTE(review): this holds for a
    square stride-aligned IMG_SIZE; confirm if IMG_SIZE changes.

    Returns:
        Boolean (H_ori, W_ori) array, True where the mask value exceeds 0.5.
    """
    H_ori, W_ori = orig_hw
    # float32 so cv2.resize also accepts boolean masks.
    mask = np.asarray(mask, dtype=np.float32)
    mask_crop = mask[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
    mask_orig = cv2.resize(mask_crop, (W_ori, H_ori), interpolation=cv2.INTER_NEAREST)
    return mask_orig > 0.5


# ====================== Draw masks & boxes ======================
def plot_mask_on_image(result, orig_shape, scale, pad_left, pad_top, new_w, new_h,
                       alpha=0.5, base_img=None, line_width=LINE_WIDTH):
    """Render the instance masks (and boxes) of *result* at original resolution.

    Args:
        result: one element of the list returned by ``model(...)``. Reads
            ``masks.data``, ``boxes`` and ``boxes.xyxy``.
        orig_shape: shape of the original image; only ``[:2]`` is used.
        scale: letterbox scale factor (kept for interface compatibility;
            the mapping uses ``new_w``/``new_h`` instead).
        pad_left, pad_top, new_w, new_h: letterbox geometry from
            ``letterbox_image``.
        alpha: opacity of the mask colour over the background.
        base_img: image to draw on (copied, never modified). When None, a
            black canvas of ``orig_shape`` is used.
        line_width: bounding-box thickness in pixels; ``<= 0`` skips boxes.

    Returns:
        A new uint8 HxWx3 image.
    """
    H_ori, W_ori = orig_shape[:2]
    orig_hw = (H_ori, W_ori)
    if base_img is not None:
        img = base_img.copy()
    else:
        img = np.zeros((H_ori, W_ori, 3), dtype=np.uint8)

    if result.masks is None or len(result.boxes) == 0:
        return img

    masks = result.masks.data.cpu().numpy()
    # Seeded so each run produces the same colours; one colour per instance.
    colors = np.random.default_rng(0).integers(0, 256, size=(len(masks), 3), dtype=np.uint8)

    overlay = img.copy()
    for mask, color in zip(masks, colors):
        overlay[_unletterbox_mask(mask, orig_hw, pad_left, pad_top, new_w, new_h)] = color
    # Outside the masks overlay == img, so only masked pixels change colour.
    img = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)

    if line_width > 0:
        xyxy = result.boxes.xyxy.cpu().numpy()
        x1, y1 = _canvas_to_orig(xyxy[:, 0], xyxy[:, 1], orig_hw, pad_left, pad_top, new_w, new_h)
        x2, y2 = _canvas_to_orig(xyxy[:, 2], xyxy[:, 3], orig_hw, pad_left, pad_top, new_w, new_h)
        for i, color in enumerate(colors):
            p1 = (int(np.rint(x1[i])), int(np.rint(y1[i])))
            p2 = (int(np.rint(x2[i])), int(np.rint(y2[i])))
            cv2.rectangle(img, p1, p2, tuple(int(c) for c in color), line_width)
    return img


# ====================== Output helpers ======================
def _segment_to_label_line(cls_id, seg, orig_hw, pad_left, pad_top, new_w, new_h):
    """Format one canvas-space polygon as a normalized label line.

    Returns:
        ``"cls x1 y1 x2 y2 ...\\n"`` with coordinates normalized to the
        original image and clipped to [0, 1]. Returns None when the polygon
        has fewer than 3 vertices, because such a polygon has no area and
        would produce an invalid label line.
    """
    if seg is None or len(seg) < 3:
        return None
    H_ori, W_ori = orig_hw
    seg = np.asarray(seg, dtype=np.float64)
    xs, ys = _canvas_to_orig(seg[:, 0], seg[:, 1], orig_hw, pad_left, pad_top, new_w, new_h)
    # Points lying in the padding map outside the image; clip them to its border.
    coords = np.clip(np.column_stack((xs / W_ori, ys / H_ori)), 0.0, 1.0)
    return f"{cls_id} " + " ".join(f"{v:.6f}" for v in coords.ravel()) + "\n"


def _collect_images(source):
    """Return the image files directly inside *source*, sorted by name."""
    if not source.is_dir():
        return []
    return sorted(
        p for p in source.iterdir()
        if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
    )


def _combined_mask(result, orig_hw, pad_left, pad_top, new_w, new_h):
    """Union of all instance masks as a uint8 image (255 = foreground)."""
    H_ori, W_ori = orig_hw
    combined = np.zeros((H_ori, W_ori), dtype=np.uint8)
    for mask in result.masks.data.cpu().numpy():
        combined[_unletterbox_mask(mask, orig_hw, pad_left, pad_top, new_w, new_h)] = 255
    return combined


# ====================== Main inference ======================
def run_segmentation():
    """Segment every image in SOURCE_IMG_DIR and write results to OUTPUT_DIR."""
    print(f"🚀 加载模型: {MODEL_PATH}")
    model = YOLO(MODEL_PATH)
    model.to(DEVICE)

    source = Path(SOURCE_IMG_DIR)
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    txt_dir = output_dir / "labels"
    mask_dir = output_dir / "masks"
    if SAVE_TXT:
        txt_dir.mkdir(exist_ok=True)
    if SAVE_MASKS:
        mask_dir.mkdir(exist_ok=True)

    img_files = _collect_images(source)
    if not img_files:
        print("❌ 未找到图片")
        return
    print(f"🖼️ 待推理图片数量: {len(img_files)}")

    for img_path in img_files:
        print(f"🔍 推理: {img_path.name}")
        orig_img = cv2.imread(str(img_path))
        if orig_img is None:
            print(" ❌ 读取失败")
            continue
        orig_hw = orig_img.shape[:2]

        # Letterbox to the model input size.
        img_input, scale, pad_left, pad_top, new_w, new_h = letterbox_image(orig_img, IMG_SIZE)

        # YOLO inference.
        results = model(img_input, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD,
                        imgsz=IMG_SIZE, device=DEVICE)
        result = results[0]
        has_masks = result.masks is not None and len(result.boxes) > 0

        # Visualization: masks blended onto the original image, plus boxes.
        plotted = plot_mask_on_image(result, orig_img.shape, scale, pad_left, pad_top,
                                     new_w, new_h, alpha=0.5, base_img=orig_img)
        save_path = output_dir / f"seg_{img_path.name}"
        if cv2.imwrite(str(save_path), plotted):
            print(f"✅ 保存结果: {save_path}")
        else:
            print(f" ❌ 保存失败: {save_path}")

        # Polygon labels, normalized to the original image.
        if SAVE_TXT and has_masks:
            txt_path = txt_dir / f"{img_path.stem}.txt"
            with open(txt_path, "w", encoding="utf-8") as f:
                for cls_id, seg in zip(result.boxes.cls.tolist(), result.masks.xy):
                    line = _segment_to_label_line(int(cls_id), seg, orig_hw,
                                                  pad_left, pad_top, new_w, new_h)
                    if line is not None:
                        f.write(line)
            print(f"📝 保存标签: {txt_path}")

        # Union of all instance masks.
        if SAVE_MASKS and has_masks:
            combined_mask = _combined_mask(result, orig_hw, pad_left, pad_top, new_w, new_h)
            mask_save_path = mask_dir / f"mask_{img_path.stem}.png"
            if cv2.imwrite(str(mask_save_path), combined_mask):
                print(f"🎨 保存掩码: {mask_save_path}")
            else:
                print(f" ❌ 保存失败: {mask_save_path}")

        # Optional interactive display; ESC stops the run early.
        if VIEW_IMG:
            cv2.imshow("Segmentation", plotted)
            if cv2.waitKey(0) == 27:
                break

    if VIEW_IMG:
        # Close the window whether or not the loop was stopped with ESC.
        cv2.destroyAllWindows()

    print(f"🎉 推理完成!结果保存到: {output_dir}")


if __name__ == "__main__":
    run_segmentation()