部署文件+resize_cls+resize_seg

This commit is contained in:
琉璃月光
2025-09-11 20:44:35 +08:00
parent 471c718d42
commit a8d117af36
877 changed files with 1736 additions and 12534 deletions

299
yemian/resize/rtest.py Normal file
View File

@ -0,0 +1,299 @@
import os
import cv2
import torch
import argparse
import numpy as np
from ultralytics import YOLO
from pathlib import Path
# ====================== 配置参数 ======================
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/seg_r/exp/weights/best.pt"
#SOURCE_IMG_DIR = "/home/hx/yolo/output_masks" # 原始输入图像目录
SOURCE_IMG_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/f6" # 原始输入图像目录
OUTPUT_DIR = "/home/hx/yolo/output_masks2" # 推理输出根目录
ROI_COORDS_FILE = "./roi_coordinates/1_rois.txt" # 必须与训练时相同
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SAVE_TXT = True
SAVE_MASKS = True
VIEW_IMG = False
LINE_WIDTH = 2
TARGET_SIZE = 640 # 必须与训练输入尺寸一致
def load_roi_coords(txt_path):
"""
加载 ROI 坐标 (x, y, w, h)
支持多个 ROI例如多视野检测
"""
rois = []
if not os.path.exists(txt_path):
raise FileNotFoundError(f"❌ ROI 文件未找到: {txt_path}")
with open(txt_path, 'r') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
try:
x, y, w, h = map(int, line.split(','))
rois.append((x, y, w, h))
print(f"📌 加载 ROI: (x={x}, y={y}, w={w}, h={h})")
except Exception as e:
print(f"⚠️ 无法解析 ROI 行: '{line}' | 错误: {e}")
return rois
def plot_result_with_opacity_and_roi(
result,
orig_img_shape,
roi_box,
line_width=2,
mask_opacity=0.5
):
"""
绘制分割结果,支持透明叠加,并适配 ROI → 原图坐标的还原
:param result: YOLO 推理结果(在 cropped-resized 图像上)
:param orig_img_shape: 原图 (H, W)
:param roi_box: (x, y, w, h) 当前 ROI 区域
:return: 在原图上绘制的结果图像
"""
h_orig, w_orig = orig_img_shape
img = np.zeros((h_orig, w_orig, 3), dtype=np.uint8) # 创建黑底原图大小画布
x, y, w, h = roi_box
# 获取模型输出的裁剪+resize后的图像上的结果
if result.masks is not None:
masks = result.masks.data.cpu().numpy() # (N, 640, 640)
# 将 mask resize 回 ROI 尺寸
resized_masks = []
for mask in masks:
m = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
m = (m > 0.5).astype(np.uint8)
resized_masks.append(m)
resized_masks = np.array(resized_masks)
# 上色叠加
overlay = img.copy()
colors = np.random.randint(0, 255, size=(len(resized_masks), 3), dtype=np.uint8)
for i, mask in enumerate(resized_masks):
color = colors[i].tolist()
roi_region = overlay[y:y+h, x:x+w]
roi_region[mask == 1] = color
cv2.addWeighted(overlay, mask_opacity, img, 1 - mask_opacity, 0, img)
# 绘制边界框和标签
if result.boxes is not None:
boxes = result.boxes.xyxy.cpu().numpy()
classes = result.boxes.cls.cpu().numpy().astype(int)
confidences = result.boxes.conf.cpu().numpy()
colors = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
thickness = 1
for i in range(len(boxes)):
# 映射回 ROI 像素坐标
box = boxes[i]
x1 = int(box[0] * w / TARGET_SIZE) + x
y1 = int(box[1] * h / TARGET_SIZE) + y
x2 = int(box[2] * w / TARGET_SIZE) + x
y2 = int(box[3] * h / TARGET_SIZE) + y
cls_id = classes[i]
conf = confidences[i]
color = colors[i].tolist()
cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
label = f"{cls_id} {conf:.2f}"
(text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
cv2.rectangle(img, (x1, y1 - text_h - 6), (x1 + text_w, y1), color, -1)
cv2.putText(img, label, (x1, y1 - 4), font, font_scale,
(255, 255, 255), thickness, cv2.LINE_AA)
return img
def run_segmentation_inference_with_roi(
model_path,
source,
output_dir,
roi_coords_file,
conf_threshold=0.25,
iou_threshold=0.45,
device="cuda:0",
save_txt=True,
save_masks=True,
view_img=False,
line_width=2,
):
print(f"🚀 加载模型: {model_path}")
print(f"💻 使用设备: {device}")
# 加载模型
model = YOLO(model_path)
# 加载 ROI
rois = load_roi_coords(roi_coords_file)
if len(rois) == 0:
print("❌ 没有加载到任何有效的 ROI程序退出。")
return
# 创建输出目录
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
txt_dir = output_dir / "labels"
mask_dir = output_dir / "masks"
vis_dir = output_dir / "visualize"
txt_dir.mkdir(parents=True, exist_ok=True)
mask_dir.mkdir(parents=True, exist_ok=True)
vis_dir.mkdir(parents=True, exist_ok=True) # ✅ 改这里:无条件创建!
# 获取图像文件
source = Path(source)
if source.is_file():
img_files = [source]
else:
img_files = list(source.glob("*.jpg")) + \
list(source.glob("*.jpeg")) + \
list(source.glob("*.png")) + \
list(source.glob("*.bmp"))
if not img_files:
print(f"❌ 在 {source} 中未找到图像文件")
return
print(f"🖼️ 共 {len(img_files)} 张图片待推理...")
# 推理循环
for img_path in img_files:
print(f"\n🔍 正在处理: {img_path.name}")
orig_img = cv2.imread(str(img_path))
if orig_img is None:
print(f"❌ 无法读取图像: {img_path}")
continue
h_orig, w_orig = orig_img.shape[:2]
# 初始化全图级结果存储
full_vis_img = orig_img.copy() # 可视化用
all_txt_lines = []
# 对每个 ROI 进行推理
for idx, (x, y, w, h) in enumerate(rois):
# 越界检查
if x < 0 or y < 0 or x + w > w_orig or y + h > h_orig:
print(f"⚠️ ROI 越界,跳过: ({x},{y},{w},{h})")
continue
# 提取 ROI 并 resize 到模型输入尺寸
roi_img = orig_img[y:y+h, x:x+w]
if roi_img.size == 0:
print(f"⚠️ 空 ROI 区域: {idx}")
continue
resized_img = cv2.resize(roi_img, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_LINEAR)
# 推理
results = model(
source=resized_img,
conf=conf_threshold,
iou=iou_threshold,
imgsz=TARGET_SIZE,
device=device,
verbose=False
)
result = results[0]
# 可视化:将结果映射回原图
vis_part = plot_result_with_opacity_and_roi(
result=result,
orig_img_shape=(h_orig, w_orig),
roi_box=(x, y, w, h),
line_width=line_width,
mask_opacity=0.5
)
full_vis_img = cv2.addWeighted(full_vis_img, 1.0, vis_part, 0.7, 0)
# 保存标签YOLO 格式,归一化到原图)
if save_txt and result.masks is not None:
scale_x = w / TARGET_SIZE
scale_y = h / TARGET_SIZE
for i in range(len(result.boxes)):
cls_id = int(result.boxes.cls[i])
seg = result.masks.xy[i] # (N, 2) in 640 space
# 映射回 ROI 像素坐标
seg[:, 0] = seg[:, 0] * scale_x + x
seg[:, 1] = seg[:, 1] * scale_y + y
# 归一化到原图
seg_norm = seg / [w_orig, h_orig]
seg_flat = seg_norm.flatten()
line = f"{cls_id} {' '.join(f'{val:.6f}' for val in seg_flat)}"
all_txt_lines.append(line)
# 保存单个 ROI 的 mask可选
if save_masks and result.masks is not None:
mask_data = result.masks.data.cpu().numpy()
combined_mask = (mask_data.sum(axis=0) > 0).astype(np.uint8) * 255
# resize 回 ROI 尺寸
roi_mask = cv2.resize(combined_mask, (w, h), interpolation=cv2.INTER_NEAREST)
full_mask = np.zeros((h_orig, w_orig), dtype=np.uint8)
full_mask[y:y+h, x:x+w] = roi_mask
mask_save_path = mask_dir / f"mask_{img_path.stem}_roi{idx}.png"
cv2.imwrite(str(mask_save_path), full_mask)
# 保存最终可视化图像
vis_save_path = vis_dir / f"vis_{img_path.name}"
cv2.imwrite(str(vis_save_path), full_vis_img)
print(f"✅ 保存可视化结果: {vis_save_path}")
# 保存合并的文本标签
if save_txt and all_txt_lines:
txt_save_path = txt_dir / f"{img_path.stem}.txt"
with open(txt_save_path, 'w') as f:
f.write("\n".join(all_txt_lines) + "\n")
print(f"📝 保存标签: {txt_save_path}")
# 实时显示
if view_img:
display_img = cv2.resize(full_vis_img, (960, 540))
cv2.imshow("Segmentation Result", display_img)
if cv2.waitKey(0) == 27: # ESC
cv2.destroyAllWindows()
break
if view_img:
cv2.destroyAllWindows()
print(f"\n🎉 推理完成!结果保存在: {output_dir}")
# ====================== 主程序 ======================
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default=MODEL_PATH, help="模型权重路径")
parser.add_argument("--source", default=SOURCE_IMG_DIR, help="图片路径或文件夹")
parser.add_argument("--output", default=OUTPUT_DIR, help="输出目录")
parser.add_argument("--roi-file", default=ROI_COORDS_FILE, help="ROI 坐标文件路径")
parser.add_argument("--conf", type=float, default=CONF_THRESHOLD, help="置信度阈值")
parser.add_argument("--iou", type=float, default=IOU_THRESHOLD, help="IoU 阈值")
parser.add_argument("--device", default=DEVICE, help="设备: cuda:0, cpu")
parser.add_argument("--view-img", action="store_true", help="显示图像")
parser.add_argument("--save-txt", action="store_true", help="保存标签")
parser.add_argument("--save-masks", action="store_true", help="保存掩码")
opt = parser.parse_args()
run_segmentation_inference_with_roi(
model_path=opt.model,
source=opt.source,
output_dir=opt.output,
roi_coords_file=opt.roi_file,
conf_threshold=opt.conf,
iou_threshold=opt.iou,
device=opt.device,
save_txt=opt.save_txt or SAVE_TXT,
save_masks=opt.save_masks or SAVE_MASKS,
view_img=opt.view_img,
line_width=LINE_WIDTH,
)