137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
|
import cv2
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from ultralytics import YOLO
|
|
import torch
|
|
|
|
# ====================== 配置 ======================
|
|
# Model weights plus input/output locations.
MODEL_PATH = "best.pt"
SOURCE_IMG_DIR = "/home/hx/yolo/yemian/test_image"
OUTPUT_DIR = "/home/hx/yolo/output_masks2"

# Thresholds forwarded to the model call (confidence and NMS IoU).
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Output and display switches.
SAVE_TXT = True
SAVE_MASKS = True
VIEW_IMG = False
LINE_WIDTH = 2  # NOTE(review): not referenced by any code visible in this file

IMG_SIZE = 640  # square YOLO input size, used for letterboxing and inference
|
|
|
|
# ====================== Letterbox 缩放函数 ======================
|
|
def letterbox_image(img, new_size=IMG_SIZE):
    """Fit *img* into a ``new_size`` x ``new_size`` canvas, preserving aspect ratio.

    The image is scaled by ``min(new_size / w, new_size / h)`` and centred.
    The border is filled with the value 114, the same pad value Ultralytics'
    LetterBox uses. The canvas keeps the input's dtype and channel layout, so
    grayscale images are supported as well.

    Args:
        img: Image array of shape ``(h, w)`` or ``(h, w, c)``.
        new_size: Side length of the square output canvas.

    Returns:
        ``(canvas, scale, pad_left, pad_top, new_w, new_h)``. ``new_w`` and
        ``new_h`` are the size of the resized content inside the canvas, and
        ``pad_left``/``pad_top`` are its offsets.

    Raises:
        ValueError: if *img* has a zero-sized height or width.
    """
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot letterbox an empty image of shape {img.shape}")

    scale = min(new_size / w, new_size / h)
    # Round instead of truncating. For example, 49 * (512 / 49) evaluates to
    # 511.999..., and int() would leave a spurious 1-pixel padding strip.
    new_w = min(new_size, max(1, round(w * scale)))
    new_h = min(new_size, max(1, round(h * scale)))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    channel_shape = img.shape[2:]
    canvas = np.full((new_size, new_size) + channel_shape, 114, dtype=img.dtype)
    pad_left = (new_size - new_w) // 2
    pad_top = (new_size - new_h) // 2
    # cv2.resize drops a trailing singleton channel axis, so restore the
    # input's channel layout before pasting.
    canvas[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized.reshape(
        (new_h, new_w) + channel_shape
    )
    return canvas, scale, pad_left, pad_top, new_w, new_h
|
|
|
|
# ====================== Draw instance masks (no bounding boxes) ======================
|
|
def plot_mask_on_image(result, orig_shape, scale, pad_left, pad_top, new_w, new_h,
                       alpha=0.5, base_img=None):
    """Render the instance masks of a letterboxed prediction at original resolution.

    Each mask is cropped to the letterbox content area, resized (nearest
    neighbour) to the original size, and painted in a random colour. The
    painted layer is then alpha-blended onto the background.

    Args:
        result: Prediction result exposing ``masks`` (``None`` when nothing was
            segmented) and ``boxes``. ``masks.data`` is assumed to be
            ``(N, H_in, W_in)`` in letterbox-canvas coordinates. TODO confirm:
            the padding crop below only lines up if so.
        orig_shape: Shape of the original image; only ``[:2]`` (H, W) is used.
        scale: Letterbox scale factor. Unused here, kept for interface compatibility.
        pad_left, pad_top: Letterbox padding offsets in pixels.
        new_w, new_h: Size of the resized content inside the letterbox canvas.
        alpha: Weight of the mask colours in the blend.
        base_img: Optional ``(H, W, 3)`` image to draw on. When omitted, a black
            uint8 canvas is used. Masked pixels then become ``alpha * colour``
            and everything else stays black.

    Returns:
        A new ``(H, W, 3)`` image; *base_img* itself is never modified.

    Raises:
        ValueError: if *base_img* is given and its shape is not ``(H, W, 3)``.
    """
    H_ori, W_ori = orig_shape[:2]
    if base_img is None:
        img = np.zeros((H_ori, W_ori, 3), dtype=np.uint8)
    else:
        if base_img.shape != (H_ori, W_ori, 3):
            raise ValueError(
                f"base_img shape {base_img.shape} does not match {(H_ori, W_ori, 3)}"
            )
        img = base_img.copy()

    if result.masks is None or len(result.boxes) == 0:
        return img

    masks = result.masks.data.cpu().numpy()
    overlay = img.copy()
    colors = np.random.randint(0, 255, (len(masks), 3), dtype=np.uint8)

    for color, mask in zip(colors, masks):
        # Drop the letterbox padding so only the real image content remains.
        crop = mask[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
        # Threshold first. cv2.resize rejects bool (and half-precision) arrays,
        # and with nearest-neighbour interpolation thresholding before or after
        # the resize selects exactly the same pixels.
        binary = (crop > 0.5).astype(np.uint8)
        mask_orig = cv2.resize(binary, (W_ori, H_ori), interpolation=cv2.INTER_NEAREST)
        overlay[mask_orig > 0] = color

    # Pixels outside all masks are identical in overlay and img, so they are unchanged.
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
    return img
|
|
|
|
# ====================== 主推理 ======================
|
|
def _collect_images(source_dir):
    """Return the image files directly inside *source_dir*, sorted by path.

    Suffix matching is case-insensitive, so ``.JPG``/``.PNG``/``.jpeg`` files
    are included too. Sorting makes the processing order deterministic.
    """
    suffixes = {".jpg", ".jpeg", ".png"}
    return sorted(
        path for path in source_dir.glob("*")
        if path.is_file() and path.suffix.lower() in suffixes
    )


def _unletterbox_mask(mask, pad_left, pad_top, new_w, new_h, out_w, out_h):
    """Strip letterbox padding from one mask and resize it to ``(out_h, out_w)``.

    Returns a uint8 array of 0/1 values. The mask is thresholded before the
    resize because cv2.resize does not accept boolean input. With
    nearest-neighbour interpolation the selected pixels are the same either way.
    """
    crop = mask[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
    binary = (crop > 0.5).astype(np.uint8)
    return cv2.resize(binary, (out_w, out_h), interpolation=cv2.INTER_NEAREST)


def _write_labels(result, txt_path, pad_left, pad_top, new_w, new_h):
    """Write one ``cls x1 y1 x2 y2 ...`` polygon line per instance to *txt_path*.

    The polygons in ``result.masks.xy`` are in letterbox-canvas pixels. The
    original image is the canvas content rescaled, so subtracting the padding
    and dividing by the content size ``(new_w, new_h)`` yields coordinates
    normalised to the original image.
    """
    offset = np.array([pad_left, pad_top], dtype=np.float64)
    content_size = np.array([new_w, new_h], dtype=np.float64)
    lines = []
    for cls_id, polygon in zip(result.boxes.cls.tolist(), result.masks.xy):
        # An empty polygon would produce a line holding only the class id,
        # which is not a valid segmentation label.
        if len(polygon) == 0:
            continue
        # Points lying in the padding would map outside [0, 1]; clamp them.
        normalised = np.clip((polygon - offset) / content_size, 0.0, 1.0)
        coords = " ".join(f"{v:.6f}" for v in normalised.ravel())
        lines.append(f"{int(cls_id)} {coords}\n")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def _save_union_mask(masks, mask_path, pad_left, pad_top, new_w, new_h, out_w, out_h):
    """Save the union of all instance masks as a 0/255 image at original size.

    Returns the success flag of ``cv2.imwrite``. Nearest-neighbour resizing
    maps every mask through the same source pixels, so taking the union
    first and resizing once matches resizing each mask separately.
    """
    if len(masks):
        union = masks.max(axis=0)
    else:
        union = np.zeros(masks.shape[1:], dtype=np.float32)
    binary = _unletterbox_mask(union, pad_left, pad_top, new_w, new_h, out_w, out_h)
    return cv2.imwrite(str(mask_path), binary * np.uint8(255))


def run_segmentation():
    """Segment every image in ``SOURCE_IMG_DIR`` and write the results to ``OUTPUT_DIR``.

    For each readable image this writes:
      * ``seg_<name>``: the visualisation returned by ``plot_mask_on_image``;
      * ``labels/<stem>.txt`` (when ``SAVE_TXT``): normalised polygons, one per instance;
      * ``masks/mask_<stem>.png`` (when ``SAVE_MASKS``): union of all masks as 0/255.
    Label and mask files are only written when the model returned masks.
    With ``VIEW_IMG`` each visualisation is shown and ESC stops the loop early.
    """
    print(f"🚀 加载模型: {MODEL_PATH}")
    model = YOLO(MODEL_PATH)
    model.to(DEVICE)

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    txt_dir = output_dir / "labels"
    mask_dir = output_dir / "masks"
    if SAVE_TXT:
        txt_dir.mkdir(exist_ok=True)
    if SAVE_MASKS:
        mask_dir.mkdir(exist_ok=True)

    img_files = _collect_images(Path(SOURCE_IMG_DIR))
    if not img_files:
        print("❌ 未找到图片")
        return
    print(f"🖼️ 待推理图片数量: {len(img_files)}")

    try:
        for img_path in img_files:
            print(f"🔍 推理: {img_path.name}")
            orig_img = cv2.imread(str(img_path))
            if orig_img is None:
                print(" ❌ 读取失败")
                continue
            H_ori, W_ori = orig_img.shape[:2]

            # The model is fed an already-letterboxed IMG_SIZE x IMG_SIZE image.
            # The helpers below therefore map masks/polygons from canvas
            # coordinates back to the original image using the pad offsets.
            img_input, scale, pad_left, pad_top, new_w, new_h = letterbox_image(orig_img, IMG_SIZE)
            results = model(img_input, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD,
                            imgsz=IMG_SIZE, device=DEVICE)
            result = results[0]

            plotted = plot_mask_on_image(result, orig_img.shape, scale, pad_left, pad_top,
                                         new_w, new_h, alpha=0.5)
            save_path = output_dir / f"seg_{img_path.name}"
            # cv2.imwrite reports failure by returning False rather than raising.
            if cv2.imwrite(str(save_path), plotted):
                print(f"✅ 保存结果: {save_path}")
            else:
                print(f" ❌ 保存失败: {save_path}")

            if SAVE_TXT and result.masks is not None:
                txt_path = txt_dir / f"{img_path.stem}.txt"
                _write_labels(result, txt_path, pad_left, pad_top, new_w, new_h)
                print(f"📝 保存标签: {txt_path}")

            if SAVE_MASKS and result.masks is not None:
                masks = result.masks.data.cpu().numpy()
                mask_save_path = mask_dir / f"mask_{img_path.stem}.png"
                if _save_union_mask(masks, mask_save_path, pad_left, pad_top,
                                    new_w, new_h, W_ori, H_ori):
                    print(f"🎨 保存掩码: {mask_save_path}")
                else:
                    print(f" ❌ 保存失败: {mask_save_path}")

            if VIEW_IMG:
                cv2.imshow("Segmentation", plotted)
                if cv2.waitKey(0) == 27:  # ESC key
                    break
    finally:
        # Close the preview window whether the loop finished, hit ESC, or raised.
        if VIEW_IMG:
            cv2.destroyAllWindows()

    print(f"🎉 推理完成!结果保存到: {output_dir}")
|
|
|
|
if __name__ == "__main__":
    # Execute the batch run only when launched as a script, never on import.
    run_segmentation()
|