bushu
This commit is contained in:
249
yemian/test.py
249
yemian/test.py
@ -1,14 +1,13 @@
|
||||
import os
|
||||
#!/usr/bin/env python3
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
# ====================== 配置参数 ======================
|
||||
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg/exp7/weights/best.pt"
|
||||
SOURCE_IMG_DIR = "/home/hx/yolo/output_masks"
|
||||
# ====================== 配置 ======================
|
||||
MODEL_PATH = "best.pt"
|
||||
SOURCE_IMG_DIR = "/home/hx/yolo/yemian/test_image"
|
||||
OUTPUT_DIR = "/home/hx/yolo/output_masks2"
|
||||
CONF_THRESHOLD = 0.25
|
||||
IOU_THRESHOLD = 0.45
|
||||
@ -17,207 +16,121 @@ SAVE_TXT = True
|
||||
SAVE_MASKS = True
|
||||
VIEW_IMG = False
|
||||
LINE_WIDTH = 2
|
||||
IMG_SIZE = 640 # YOLO 输入尺寸
|
||||
|
||||
def plot_result_with_opacity(result, line_width=2, mask_opacity=0.5):
|
||||
"""
|
||||
手动绘制 YOLO 分割结果,支持掩码透明度叠加,并修复掩码尺寸不匹配问题
|
||||
"""
|
||||
img = result.orig_img.copy() # HWC, BGR
|
||||
# ====================== Letterbox 缩放函数 ======================
|
||||
def letterbox_image(img, new_size=IMG_SIZE):
|
||||
h, w = img.shape[:2]
|
||||
scale = min(new_size / w, new_size / h)
|
||||
new_w, new_h = int(w*scale), int(h*scale)
|
||||
resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
canvas = np.full((new_size, new_size, 3), 114, dtype=np.uint8)
|
||||
pad_w, pad_h = new_size - new_w, new_size - new_h
|
||||
pad_top, pad_left = pad_h // 2, pad_w // 2
|
||||
canvas[pad_top:pad_top+new_h, pad_left:pad_left+new_w] = resized
|
||||
return canvas, scale, pad_left, pad_top, new_w, new_h
|
||||
|
||||
# 获取原始图像尺寸
|
||||
orig_shape = img.shape[:2] # (height, width)
|
||||
# ====================== 绘制 mask & 边框 ======================
|
||||
def plot_mask_on_image(result, orig_shape, scale, pad_left, pad_top, new_w, new_h, alpha=0.5):
|
||||
H_ori, W_ori = orig_shape[:2]
|
||||
img = np.zeros((H_ori, W_ori, 3), dtype=np.uint8)
|
||||
|
||||
if result.masks is not None and len(result.boxes) > 0:
|
||||
# 将掩码从 GPU 移到 CPU 并转为 numpy
|
||||
masks = result.masks.data.cpu().numpy() # (N, H_mask, W_mask)
|
||||
|
||||
# resize 掩码到原始图像尺寸
|
||||
resized_masks = []
|
||||
for mask in masks:
|
||||
mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
mask_resized = (mask_resized > 0.5).astype(np.uint8) # 二值化
|
||||
resized_masks.append(mask_resized)
|
||||
resized_masks = np.array(resized_masks)
|
||||
|
||||
# 随机颜色 (BGR)
|
||||
num_masks = len(result.boxes)
|
||||
colors = np.random.randint(0, 255, size=(num_masks, 3), dtype=np.uint8)
|
||||
|
||||
# 创建叠加层
|
||||
masks = result.masks.data.cpu().numpy() # (N, IMG_SIZE, IMG_SIZE)
|
||||
overlay = img.copy()
|
||||
for i in range(num_masks):
|
||||
color = colors[i].tolist()
|
||||
mask_resized = resized_masks[i]
|
||||
overlay[mask_resized == 1] = color
|
||||
num_masks = len(masks)
|
||||
colors = np.random.randint(0,255,(num_masks,3),dtype=np.uint8)
|
||||
|
||||
# 透明叠加
|
||||
cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)
|
||||
for i, mask in enumerate(masks):
|
||||
# 去掉 padding
|
||||
mask_crop = mask[pad_top:pad_top+new_h, pad_left:pad_left+new_w]
|
||||
# resize 回原图
|
||||
mask_orig = cv2.resize(mask_crop, (W_ori, H_ori), interpolation=cv2.INTER_NEAREST)
|
||||
overlay[mask_orig>0.5] = colors[i].tolist()
|
||||
|
||||
# 绘制边界框和标签(保持不变)
|
||||
if result.boxes is not None:
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
classes = result.boxes.cls.cpu().numpy().astype(int)
|
||||
confidences = result.boxes.conf.cpu().numpy()
|
||||
|
||||
colors = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 1
|
||||
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i].astype(int)
|
||||
cls_id = classes[i]
|
||||
conf = confidences[i]
|
||||
color = colors[i].tolist()
|
||||
|
||||
# 绘制矩形框
|
||||
cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, line_width)
|
||||
|
||||
# 标签文本
|
||||
label = f"{cls_id} {conf:.2f}"
|
||||
|
||||
# 获取文本大小
|
||||
(text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
|
||||
text_h += baseline
|
||||
|
||||
# 绘制标签背景
|
||||
cv2.rectangle(img, (box[0], box[1] - text_h - 6), (box[0] + text_w, box[1]), color, -1)
|
||||
# 绘制文本
|
||||
cv2.putText(img, label, (box[0], box[1] - 4), font, font_scale,
|
||||
(255, 255, 255), thickness, cv2.LINE_AA)
|
||||
cv2.addWeighted(overlay, alpha, img, 1-alpha, 0, img)
|
||||
|
||||
return img
|
||||
|
||||
# ====================== 主推理 ======================
|
||||
def run_segmentation():
|
||||
print(f"🚀 加载模型: {MODEL_PATH}")
|
||||
model = YOLO(MODEL_PATH)
|
||||
model.to(DEVICE)
|
||||
|
||||
def run_segmentation_inference(
|
||||
model_path,
|
||||
source,
|
||||
output_dir,
|
||||
conf_threshold=0.25,
|
||||
iou_threshold=0.45,
|
||||
device="cuda:0",
|
||||
save_txt=True,
|
||||
save_masks=True,
|
||||
view_img=False,
|
||||
line_width=2,
|
||||
):
|
||||
print(f"🚀 加载模型: {model_path}")
|
||||
print(f"💻 使用设备: {device}")
|
||||
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(output_dir)
|
||||
source = Path(SOURCE_IMG_DIR)
|
||||
output_dir = Path(OUTPUT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
txt_dir = output_dir / "labels"
|
||||
mask_dir = output_dir / "masks"
|
||||
if save_txt:
|
||||
txt_dir.mkdir(exist_ok=True)
|
||||
if save_masks:
|
||||
mask_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 获取图像文件列表
|
||||
source = Path(source)
|
||||
if source.is_file():
|
||||
img_files = [source]
|
||||
else:
|
||||
img_files = list(source.glob("*.jpg")) + \
|
||||
list(source.glob("*.jpeg")) + \
|
||||
list(source.glob("*.png")) + \
|
||||
list(source.glob("*.bmp"))
|
||||
if SAVE_TXT: txt_dir.mkdir(exist_ok=True)
|
||||
if SAVE_MASKS: mask_dir.mkdir(exist_ok=True)
|
||||
|
||||
img_files = list(source.glob("*.jpg")) + list(source.glob("*.png"))
|
||||
if not img_files:
|
||||
print(f"❌ 在 {source} 中未找到图像文件")
|
||||
print(f"❌ 未找到图片")
|
||||
return
|
||||
print(f"🖼️ 待推理图片数量: {len(img_files)}")
|
||||
|
||||
print(f"🖼️ 共 {len(img_files)} 张图片待推理...")
|
||||
|
||||
# 推理循环
|
||||
for img_path in img_files:
|
||||
print(f"🔍 推理: {img_path.name}")
|
||||
orig_img = cv2.imread(str(img_path))
|
||||
if orig_img is None:
|
||||
print(" ❌ 读取失败")
|
||||
continue
|
||||
H_ori, W_ori = orig_img.shape[:2]
|
||||
|
||||
# 执行推理
|
||||
results = model(
|
||||
source=str(img_path),
|
||||
conf=conf_threshold,
|
||||
iou=iou_threshold,
|
||||
imgsz=640,
|
||||
device=device,
|
||||
verbose=True
|
||||
)
|
||||
# Letterbox 缩放
|
||||
img_input, scale, pad_left, pad_top, new_w, new_h = letterbox_image(orig_img, IMG_SIZE)
|
||||
|
||||
# YOLO 推理
|
||||
results = model(img_input, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD, imgsz=IMG_SIZE, device=DEVICE)
|
||||
result = results[0]
|
||||
orig_img = result.orig_img # 原始图像
|
||||
|
||||
# ✅ 使用自定义绘制函数(支持透明度)
|
||||
plotted = plot_result_with_opacity(result, line_width=line_width, mask_opacity=0.5)
|
||||
# 可视化 mask
|
||||
plotted = plot_mask_on_image(result, orig_img.shape, scale, pad_left, pad_top, new_w, new_h, alpha=0.5)
|
||||
|
||||
# 保存可视化图像
|
||||
# 保存结果
|
||||
save_path = output_dir / f"seg_{img_path.name}"
|
||||
cv2.imwrite(str(save_path), plotted)
|
||||
print(f"✅ 保存结果: {save_path}")
|
||||
|
||||
# 保存 YOLO 格式标签(多边形)
|
||||
if save_txt and result.masks is not None:
|
||||
txt_path = txt_dir / (img_path.stem + ".txt")
|
||||
with open(txt_path, 'w') as f:
|
||||
# 保存标签
|
||||
if SAVE_TXT and result.masks is not None:
|
||||
txt_path = txt_dir / f"{img_path.stem}.txt"
|
||||
with open(txt_path,"w") as f:
|
||||
for i in range(len(result.boxes)):
|
||||
cls_id = int(result.boxes.cls[i])
|
||||
seg = result.masks.xy[i] # 多边形点 (N, 2)
|
||||
seg = seg.flatten()
|
||||
seg = seg / [orig_img.shape[1], orig_img.shape[0]] # 归一化
|
||||
seg = seg.tolist()
|
||||
line = f"{cls_id} {' '.join(f'{x:.6f}' for x in seg)}\n"
|
||||
seg = result.masks.xy[i].copy()
|
||||
# 去掉 padding + scale 回原图
|
||||
seg[:,0] = (seg[:,0] - pad_left) * (W_ori / new_w)
|
||||
seg[:,1] = (seg[:,1] - pad_top) * (H_ori / new_h)
|
||||
seg_norm = seg / [W_ori, H_ori]
|
||||
seg_flat = seg_norm.flatten().tolist()
|
||||
line = f"{cls_id} " + " ".join(f"{x:.6f}" for x in seg_flat) + "\n"
|
||||
f.write(line)
|
||||
print(f"📝 保存标签: {txt_path}")
|
||||
|
||||
# 保存合并的掩码图
|
||||
if save_masks and result.masks is not None:
|
||||
mask = result.masks.data.cpu().numpy()
|
||||
combined_mask = (mask.sum(axis=0) > 0).astype(np.uint8) * 255 # 合并所有掩码
|
||||
# 保存 mask
|
||||
if SAVE_MASKS and result.masks is not None:
|
||||
masks = result.masks.data.cpu().numpy()
|
||||
combined_mask = np.zeros((H_ori, W_ori), dtype=np.uint8)
|
||||
for mask in masks:
|
||||
mask_crop = mask[pad_top:pad_top+new_h, pad_left:pad_left+new_w]
|
||||
mask_orig = cv2.resize(mask_crop, (W_ori, H_ori), interpolation=cv2.INTER_NEAREST)
|
||||
combined_mask = np.maximum(combined_mask, (mask_orig>0.5).astype(np.uint8)*255)
|
||||
mask_save_path = mask_dir / f"mask_{img_path.stem}.png"
|
||||
cv2.imwrite(str(mask_save_path), combined_mask)
|
||||
print(f"🎨 保存掩码: {mask_save_path}")
|
||||
|
||||
# 实时显示(可选)
|
||||
if view_img:
|
||||
cv2.imshow("Segmentation Result", plotted)
|
||||
if cv2.waitKey(0) == 27: # ESC 退出
|
||||
# 显示
|
||||
if VIEW_IMG:
|
||||
cv2.imshow("Segmentation", plotted)
|
||||
if cv2.waitKey(0)==27:
|
||||
cv2.destroyAllWindows()
|
||||
break
|
||||
|
||||
if view_img:
|
||||
cv2.destroyAllWindows()
|
||||
print(f"\n🎉 推理完成!结果保存在: {output_dir}")
|
||||
print(f"🎉 推理完成!结果保存到: {output_dir}")
|
||||
|
||||
|
||||
# ====================== 主程序 ======================
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
|
||||
parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
|
||||
parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
|
||||
parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
|
||||
parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
|
||||
parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
|
||||
parser.add_argument("--view-img", action="store_true", help="显示图像")
|
||||
parser.add_argument("--save-txt", action="store_true", help="保存标签")
|
||||
parser.add_argument("--save-masks", action="store_true", help="保存掩码")
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
run_segmentation_inference(
|
||||
model_path=opt.model,
|
||||
source=opt.source,
|
||||
output_dir=opt.output,
|
||||
conf_threshold=opt.conf,
|
||||
iou_threshold=opt.iou,
|
||||
device=opt.device,
|
||||
save_txt=opt.save_txt,
|
||||
save_masks=opt.save_masks,
|
||||
view_img=opt.view_img,
|
||||
line_width=LINE_WIDTH,
|
||||
)
|
||||
if __name__=="__main__":
|
||||
run_segmentation()
|
||||
|
||||
Reference in New Issue
Block a user