Files
zjsh_yolov11/yemian/resize/rtest.py
琉璃月光 df7c0730f5 bushu
2025-10-21 14:11:52 +08:00

299 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import torch
import argparse
import numpy as np
from ultralytics import YOLO
from pathlib import Path
# ====================== Configuration ======================
MODEL_PATH = "best.pt"  # YOLO segmentation weights
# Alternative input directory, kept for quick switching:
# SOURCE_IMG_DIR = "/home/hx/yolo/output_masks"
SOURCE_IMG_DIR = "/home/hx/yolo/yemian/test_image"  # input image file or directory
OUTPUT_DIR = "/home/hx/yolo/output_masks2"  # root directory for all inference outputs
ROI_COORDS_FILE = "./roi_coordinates/1_rois2.txt"  # must be the same ROI file used for training
CONF_THRESHOLD = 0.25  # detection confidence threshold passed to the model
IOU_THRESHOLD = 0.45  # IoU threshold passed to the model
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # first CUDA device when available
SAVE_TXT = True  # write YOLO-format polygon label files
SAVE_MASKS = True  # write per-ROI binary mask PNGs
VIEW_IMG = False  # show each visualization in an OpenCV window
LINE_WIDTH = 2  # box outline thickness in pixels
TARGET_SIZE = 640  # square model input size; must match the training input size
def load_roi_coords(txt_path):
    """Load ROI rectangles ``(x, y, w, h)`` from a text file.

    Each non-empty line that does not start with ``#`` must contain four
    comma-separated integers ``x,y,w,h``. Several lines define several ROIs
    (e.g. multi-view detection). Lines that cannot be parsed are reported and
    skipped.

    :param txt_path: path to the ROI coordinate file.
    :return: list of ``(x, y, w, h)`` integer tuples, in file order.
    :raises FileNotFoundError: if ``txt_path`` does not exist.
    """
    rois = []
    if not os.path.exists(txt_path):
        raise FileNotFoundError(f"❌ ROI 文件未找到: {txt_path}")
    # utf-8-sig: decodes plain UTF-8 unchanged but also drops a leading BOM,
    # which would otherwise make the first ROI line unparseable. Explicit
    # encoding also keeps non-ASCII comment lines independent of the locale.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            try:
                # ValueError covers both non-integer fields and a wrong field count.
                x, y, w, h = map(int, line.split(','))
            except ValueError as e:
                print(f"⚠️ 无法解析 ROI 行: '{line}' | 错误: {e}")
                continue
            rois.append((x, y, w, h))
            print(f"📌 加载 ROI: (x={x}, y={y}, w={w}, h={h})")
    return rois
def plot_result_with_opacity_and_roi(
    result,
    orig_img_shape,
    roi_box,
    line_width=2,
    mask_opacity=0.5,
    target_size=None,
):
    """Draw one ROI's segmentation result on a black canvas of the original image size.

    The model ran on the ROI crop resized to a square ``target_size`` input, so
    masks and boxes are scaled back to the ROI's own ``(w, h)`` and offset by the
    ROI origin ``(x, y)``.

    Each detection gets one random color, shared by its mask and its box/label,
    so the two can be matched visually.

    :param result: YOLO result computed on the cropped-and-resized ROI image;
        reads ``result.masks.data`` and ``result.boxes.xyxy/cls/conf``.
    :param orig_img_shape: ``(H, W)`` of the original image (canvas size).
    :param roi_box: ``(x, y, w, h)`` of the ROI in original-image pixels.
    :param line_width: box outline thickness in pixels.
    :param mask_opacity: weight of the colored mask layer when blended onto the
        black canvas.
    :param target_size: side length of the square model input; defaults to the
        module-level ``TARGET_SIZE``.
    :return: ``(H, W, 3)`` uint8 image, black wherever nothing was drawn.
    """
    if target_size is None:
        target_size = TARGET_SIZE
    h_orig, w_orig = orig_img_shape
    img = np.zeros((h_orig, w_orig, 3), dtype=np.uint8)  # black canvas, original size
    x, y, w, h = roi_box

    has_masks = result.masks is not None
    has_boxes = result.boxes is not None
    n_masks = len(result.masks.data) if has_masks else 0
    n_boxes = len(result.boxes.xyxy) if has_boxes else 0
    # One color per detection, reused for both its mask and its box.
    colors = np.random.randint(0, 255, size=(max(n_masks, n_boxes), 3), dtype=np.uint8)

    if has_masks:
        # float32 because cv2.resize does not accept float16 input.
        masks = result.masks.data.cpu().numpy().astype(np.float32)
        overlay = img.copy()
        roi_region = overlay[y:y + h, x:x + w]  # a view: writes land in overlay
        for i, mask in enumerate(masks):
            # Scale the square-input mask back to the ROI size.
            m = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            roi_region[m > 0.5] = colors[i].tolist()
        cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)

    if has_boxes:
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1
        for i, (box, cls_id, conf) in enumerate(zip(boxes, classes, confidences)):
            # Map from the square model-input space to original-image pixels.
            x1 = int(box[0] * w / target_size) + x
            y1 = int(box[1] * h / target_size) + y
            x2 = int(box[2] * w / target_size) + x
            y2 = int(box[3] * h / target_size) + y
            color = colors[i].tolist()
            cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
            label = f"{cls_id} {conf:.2f}"
            (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
            # Filled label background sitting just above the box's top-left corner.
            cv2.rectangle(img, (x1, y1 - text_h - 6), (x1 + text_w, y1), color, -1)
            cv2.putText(img, label, (x1, y1 - 4), font, font_scale,
                        (255, 255, 255), thickness, cv2.LINE_AA)
    return img
# Image extensions accepted when ``source`` is a directory (matched case-insensitively).
_IMG_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def _collect_image_files(source):
    """Return ``[source]`` for a file, else the sorted image files directly inside the directory."""
    if source.is_file():
        return [source]
    if not source.is_dir():
        return []
    return sorted(
        p for p in source.iterdir()
        if p.is_file() and p.suffix.lower() in _IMG_SUFFIXES
    )


def _roi_segments_to_yolo_lines(result, roi_box, orig_shape):
    """Convert one ROI result's mask polygons to YOLO-seg lines normalized to the full image."""
    x, y, w, h = roi_box
    h_orig, w_orig = orig_shape
    scale = np.array([w / TARGET_SIZE, h / TARGET_SIZE])
    offset = np.array([x, y])
    norm = np.array([w_orig, h_orig])
    lines = []
    for cls, seg in zip(result.boxes.cls, result.masks.xy):
        if len(seg) == 0:
            # A tiny mask may have no contour; an empty polygon is not a valid label line.
            continue
        # Build a new array instead of editing ``seg`` in place, since the
        # result object may cache ``masks.xy``.
        seg_norm = (seg * scale + offset) / norm
        coords = ' '.join(f'{val:.6f}' for val in seg_norm.flatten())
        lines.append(f"{int(cls)} {coords}")
    return lines


def _roi_mask_canvas(result, roi_box, orig_shape):
    """Union of one ROI's instance masks, pasted into a full-size 0/255 uint8 canvas."""
    x, y, w, h = roi_box
    h_orig, w_orig = orig_shape
    mask_data = result.masks.data.cpu().numpy()
    combined_mask = (mask_data.sum(axis=0) > 0).astype(np.uint8) * 255
    roi_mask = cv2.resize(combined_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    full_mask = np.zeros((h_orig, w_orig), dtype=np.uint8)
    full_mask[y:y + h, x:x + w] = roi_mask
    return full_mask


def run_segmentation_inference_with_roi(
    model_path,
    source,
    output_dir,
    roi_coords_file,
    conf_threshold=0.25,
    iou_threshold=0.45,
    device="cuda:0",
    save_txt=True,
    save_masks=True,
    view_img=False,
    line_width=2,
):
    """Run YOLO segmentation on fixed ROIs of each image and save the results.

    For every image, each ROI from ``roi_coords_file`` is cropped, resized to
    ``TARGET_SIZE`` x ``TARGET_SIZE`` and passed to the model. Results are mapped
    back to original-image coordinates and written under ``output_dir``:

    * ``visualize/vis_<name>``: image with all ROI overlays (always written);
    * ``labels/<stem>.txt``: YOLO-seg polygons normalized to the full image
      (when ``save_txt`` and at least one polygon was found);
    * ``masks/mask_<stem>_roi<idx>.png``: per-ROI union mask on a full-size
      canvas (when ``save_masks``).

    ROIs that do not fit inside an image are skipped for that image.

    :param model_path: path to the YOLO weights.
    :param source: an image file, or a directory whose .jpg/.jpeg/.png/.bmp
        files (any case, non-recursive) are processed in sorted order.
    :param output_dir: root output directory (created if missing).
    :param roi_coords_file: ROI file, see ``load_roi_coords``.
    :param conf_threshold: confidence threshold passed to the model.
    :param iou_threshold: IoU threshold passed to the model.
    :param device: device string passed to each model call.
    :param save_txt: write label files.
    :param save_masks: write per-ROI mask PNGs.
    :param view_img: show each visualization and wait for a key; ESC stops.
    :param line_width: box outline thickness for the visualization.
    :return: None; returns early with a message when no ROI or no image is found.
    """
    print(f"🚀 加载模型: {model_path}")
    print(f"💻 使用设备: {device}")
    model = YOLO(model_path)

    rois = load_roi_coords(roi_coords_file)
    if not rois:
        print("❌ 没有加载到任何有效的 ROI程序退出。")
        return

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    txt_dir = output_dir / "labels"
    mask_dir = output_dir / "masks"
    vis_dir = output_dir / "visualize"
    # All sub-directories are created up front; visualizations are always saved.
    for sub_dir in (txt_dir, mask_dir, vis_dir):
        sub_dir.mkdir(parents=True, exist_ok=True)

    source = Path(source)
    img_files = _collect_image_files(source)
    if not img_files:
        print(f"❌ 在 {source} 中未找到图像文件")
        return
    print(f"🖼️ 共 {len(img_files)} 张图片待推理...")

    for img_path in img_files:
        print(f"\n🔍 正在处理: {img_path.name}")
        orig_img = cv2.imread(str(img_path))
        if orig_img is None:
            print(f"❌ 无法读取图像: {img_path}")
            continue
        h_orig, w_orig = orig_img.shape[:2]
        full_vis_img = orig_img.copy()  # accumulates every ROI's overlay
        all_txt_lines = []

        for idx, (x, y, w, h) in enumerate(rois):
            if x < 0 or y < 0 or x + w > w_orig or y + h > h_orig:
                print(f"⚠️ ROI 越界,跳过: ({x},{y},{w},{h})")
                continue
            roi_img = orig_img[y:y+h, x:x+w]
            if roi_img.size == 0:
                print(f"⚠️ 空 ROI 区域: {idx}")
                continue
            # The model must see the ROI at the same square size it was trained on.
            resized_img = cv2.resize(roi_img, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_LINEAR)
            result = model(
                source=resized_img,
                conf=conf_threshold,
                iou=iou_threshold,
                imgsz=TARGET_SIZE,
                device=device,
                verbose=False
            )[0]

            vis_part = plot_result_with_opacity_and_roi(
                result=result,
                orig_img_shape=(h_orig, w_orig),
                roi_box=(x, y, w, h),
                line_width=line_width,
                mask_opacity=0.5
            )
            full_vis_img = cv2.addWeighted(full_vis_img, 1.0, vis_part, 0.7, 0)

            if result.masks is None:
                continue
            if save_txt:
                all_txt_lines.extend(
                    _roi_segments_to_yolo_lines(result, (x, y, w, h), (h_orig, w_orig))
                )
            if save_masks:
                full_mask = _roi_mask_canvas(result, (x, y, w, h), (h_orig, w_orig))
                mask_save_path = mask_dir / f"mask_{img_path.stem}_roi{idx}.png"
                cv2.imwrite(str(mask_save_path), full_mask)

        vis_save_path = vis_dir / f"vis_{img_path.name}"
        # cv2.imwrite signals failure via its return value, not an exception.
        if cv2.imwrite(str(vis_save_path), full_vis_img):
            print(f"✅ 保存可视化结果: {vis_save_path}")
        else:
            print(f"❌ 无法保存可视化结果: {vis_save_path}")

        if save_txt and all_txt_lines:
            txt_save_path = txt_dir / f"{img_path.stem}.txt"
            with open(txt_save_path, 'w', encoding='utf-8') as f:
                f.write("\n".join(all_txt_lines) + "\n")
            print(f"📝 保存标签: {txt_save_path}")

        if view_img:
            display_img = cv2.resize(full_vis_img, (960, 540))
            cv2.imshow("Segmentation Result", display_img)
            if cv2.waitKey(0) == 27:  # ESC: stop; windows are closed below
                break

    if view_img:
        cv2.destroyAllWindows()
    print(f"\n🎉 推理完成!结果保存在: {output_dir}")
# ====================== Entry point ======================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
    parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
    parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
    parser.add_argument("--roi-file", default=ROI_COORDS_FILE, help="ROI 坐标文件路径")
    parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
    parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
    parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
    parser.add_argument("--line-width", type=int, default=LINE_WIDTH, help="边框线宽")
    # On/off switch pairs share one dest. Defaults come from the config constants,
    # and the --no-* forms allow turning a default-on option off. The previous
    # `opt.save_txt or SAVE_TXT` ignored the command line whenever SAVE_TXT was True.
    parser.add_argument("--view-img", dest="view_img", action="store_true", help="显示图像")
    parser.add_argument("--no-view-img", dest="view_img", action="store_false", help="不显示图像")
    parser.add_argument("--save-txt", dest="save_txt", action="store_true", help="保存标签")
    parser.add_argument("--no-save-txt", dest="save_txt", action="store_false", help="不保存标签")
    parser.add_argument("--save-masks", dest="save_masks", action="store_true", help="保存掩码")
    parser.add_argument("--no-save-masks", dest="save_masks", action="store_false", help="不保存掩码")
    parser.set_defaults(view_img=VIEW_IMG, save_txt=SAVE_TXT, save_masks=SAVE_MASKS)
    opt = parser.parse_args()
    run_segmentation_inference_with_roi(
        model_path=opt.model,
        source=opt.source,
        output_dir=opt.output,
        roi_coords_file=opt.roi_file,
        conf_threshold=opt.conf,
        iou_threshold=opt.iou,
        device=opt.device,
        save_txt=opt.save_txt,
        save_masks=opt.save_masks,
        view_img=opt.view_img,
        line_width=opt.line_width,
    )