Files
zjsh_yolov11/yolo11_seg/yolo_seg_infer_vis—60f.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

167 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from ultralytics import YOLO
# ================= Configuration =================
# Weights file handed to YOLO(...) when the model is loaded.
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/60seg/exp3/weights/best.pt"
# Folder whose image files are listed and run through the model.
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/分割60/class4/1"
# Root output folder; "labels" and "images" sub-folders are created below it.
OUT_DIR = "./outputs"
# Passed to the model call as `imgsz`.
IMG_SIZE = 640
# Passed to the model call as `conf` and re-checked when masks are extracted.
CONF_THRES = 0.25
# Polygon simplification ratio (the main knob controlling vertex count):
# it is multiplied by the contour perimeter to get the approxPolyDP epsilon.
EPSILON_RATIO = 0.001
# File extensions (matched against the lower-cased file name) treated as images.
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp")
# -------- Save switches --------
SAVE_LABELS = True # whether to save YOLO seg label files
SAVE_VIS = True # whether to save visualization images
# ======================================
def simplify_polygon(poly, epsilon_ratio):
    """Simplify a closed polygon with cv2.approxPolyDP.

    Args:
        poly: Array-like of (x, y) vertices; anything that can be reshaped
            to (N, 2).
        epsilon_ratio: Approximation tolerance expressed as a fraction of the
            closed contour's perimeter; larger values drop more vertices.

    Returns:
        A float32 array of shape (M, 2) holding the simplified vertices.
        Inputs with fewer than three vertices are returned as-is (as float32)
        because they cannot form a polygon and need no simplification.
    """
    # Keep sub-pixel precision: OpenCV accepts float32 contours, whereas an
    # int32 cast would truncate every coordinate before simplification.
    points = np.asarray(poly, dtype=np.float32).reshape(-1, 2)

    # Degenerate contours (e.g. empty segments) are returned before OpenCV is
    # involved, so the result does not depend on how OpenCV treats them.
    if len(points) < 3:
        return points

    perimeter = cv2.arcLength(points, True)
    approx = cv2.approxPolyDP(points, epsilon_ratio * perimeter, True)
    return approx.reshape(-1, 2)
def extract_simplified_masks(result):
    """Collect simplified mask polygons from one YOLO result.

    Detections whose confidence is below CONF_THRES are dropped, as are
    polygons that end up with fewer than three vertices after simplification.

    Returns:
        A list of (cls_id, poly) tuples; empty when the result has no masks.
    """
    mask_data = result.masks
    if mask_data is None:
        return []

    detections = result.boxes
    kept = []
    for index, contour in enumerate(mask_data.xy):
        label = int(detections.cls[index])
        score = float(detections.conf[index])
        if score < CONF_THRES:
            continue

        reduced = simplify_polygon(contour, EPSILON_RATIO)
        if len(reduced) >= 3:
            kept.append((label, reduced))
    return kept
def save_yolo_seg_labels(masks, img_shape, save_path):
    """Write YOLO segmentation labels to *save_path*.

    Each (cls_id, poly) pair becomes one line: the class id followed by the
    polygon's x/y coordinates divided by the image width/height and formatted
    with six decimals. The file is created even when there are no objects, in
    which case it is empty.
    """
    height, width = img_shape[:2]

    rows = []
    for cls_id, poly in masks:
        coords = []
        for px, py in poly:
            coords += (f"{px / width:.6f}", f"{py / height:.6f}")
        joined = " ".join(coords)
        rows.append(f"{cls_id} {joined}")

    # Joining an empty list yields "", so the no-object case leaves an empty file.
    with open(save_path, "w") as handle:
        handle.write("\n".join(rows))
def draw_polygons(img, masks):
    """Return a copy of *img* with the segmentation polygons drawn on it.

    Every (cls_id, poly) pair is rendered as a closed outline, and the class
    id is written as text at the polygon's first vertex. The input image
    itself is left untouched.
    """
    colour = (0, 255, 0)
    canvas = img.copy()

    for label, contour in masks:
        pts = contour.astype(np.int32)
        cv2.polylines(canvas, [pts], True, colour, 2)

        first_x, first_y = pts[0]
        anchor = (int(first_x), int(first_y))
        cv2.putText(
            canvas, str(label), anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.6, colour, 2
        )

    return canvas
def run_folder_inference():
    """Run segmentation on every image in IMAGE_DIR and save the results.

    For each readable image whose name ends with one of IMG_EXTS, the model
    is run once, the predicted masks are simplified into polygons, and then,
    depending on the SAVE_LABELS / SAVE_VIS switches, a YOLO seg label file
    is written to OUT_DIR/labels and a visualization image to OUT_DIR/images.
    Unreadable images are skipped with a warning, and a failed visualization
    write is reported instead of being silently ignored.
    """
    # Output directories
    out_lbl_dir = os.path.join(OUT_DIR, "labels")
    out_img_dir = os.path.join(OUT_DIR, "images")
    if SAVE_LABELS:
        os.makedirs(out_lbl_dir, exist_ok=True)
    if SAVE_VIS:
        os.makedirs(out_img_dir, exist_ok=True)

    # Load the model once and reuse it for every image.
    model = YOLO(MODEL_PATH)

    img_files = sorted(
        f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(IMG_EXTS)
    )
    total = len(img_files)
    print(f"📂 共检测 {total} 张图片")

    for idx, img_name in enumerate(img_files, 1):
        img_path = os.path.join(IMAGE_DIR, img_name)
        img = cv2.imread(img_path)
        if img is None:
            print(f"⚠️ 跳过无法读取: {img_name}")
            continue

        results = model(
            img,
            imgsz=IMG_SIZE,
            conf=CONF_THRES,
            verbose=False,
        )
        result = results[0]
        masks = extract_simplified_masks(result)
        base_name = os.path.splitext(img_name)[0]

        # ---------- Save labels ----------
        if SAVE_LABELS:
            label_path = os.path.join(out_lbl_dir, base_name + ".txt")
            save_yolo_seg_labels(masks, img.shape, label_path)

        # ---------- Save visualization ----------
        if SAVE_VIS:
            vis_img = draw_polygons(img, masks)
            vis_path = os.path.join(out_img_dir, img_name)
            # cv2.imwrite signals failure through its boolean return value
            # rather than an exception, so it has to be checked explicitly.
            if not cv2.imwrite(vis_path, vis_img):
                print(f"⚠️ 可视化保存失败: {vis_path}")

        print(f"[{idx}/{total}] ✅ {img_name}")

    print("🎉 推理完成")
# Run the folder inference only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_folder_inference()