Files
zjsh_yolov11/yolo11_seg/yolo_seg_infer_vis—f.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

126 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from ultralytics import YOLO
# ---------------- Configuration ----------------
# NOTE: the paths below are machine-specific absolute paths; adjust before running.
# Trained segmentation weights, loaded once by YOLO().
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp2/weights/best.pt"
# Folder of images to label (listed with os.listdir, i.e. not recursive).
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/cls-61/class4c/12.17-18"
OUT_DIR = "./outputs"  # only label files are saved, under OUT_DIR/labels
IMG_SIZE = 640  # inference size, passed to the model as `imgsz`
# Confidence threshold: passed to the model as `conf` and re-checked per instance.
CONF_THRES = 0.25
# Polygon simplification ratio (the main knob controlling vertex count):
# the approxPolyDP tolerance is this fraction of each polygon's perimeter,
# so larger values keep fewer vertices.
EPSILON_RATIO = 0.001
# Accepted image extensions, matched case-insensitively against file names.
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp")
# -------------------------------------
def simplify_polygon(poly, epsilon_ratio):
    """
    Simplify a closed polygon with ``cv2.approxPolyDP``.

    Args:
        poly: array-like of (x, y) vertices in pixel coordinates; presumably an
            (N, 2) float array such as one entry of ``result.masks.xy``
            (TODO confirm for any other caller).
        epsilon_ratio: approximation tolerance expressed as a fraction of the
            closed polygon's perimeter; larger values keep fewer vertices.

    Returns:
        (M, 2) int32 array of the simplified vertices. Inputs with fewer than
        three vertices cannot form a polygon; they are returned (rounded)
        without simplification so callers can drop them with a length check.
    """
    # Round to the nearest pixel. A bare astype(int32) truncates toward zero,
    # which shifts every vertex up/left by up to one pixel.
    pts = np.rint(np.asarray(poly, dtype=np.float64)).astype(np.int32).reshape(-1, 2)

    # Degenerate input (e.g. the empty (0, 2) segment produced for a tiny mask):
    # approxPolyDP yields an empty result that the Python binding returns as
    # None, so the .reshape below would raise. Skip the OpenCV calls instead.
    if len(pts) < 3:
        return pts

    perimeter = cv2.arcLength(pts, True)
    approx = cv2.approxPolyDP(pts, epsilon_ratio * perimeter, True)
    return approx.reshape(-1, 2)
def extract_simplified_masks(result):
    """
    Convert one YOLO segmentation result into simplified polygons.

    Instances scoring below ``CONF_THRES`` are skipped. Each remaining mask
    outline is passed through ``simplify_polygon(..., EPSILON_RATIO)``, and
    outlines left with fewer than three vertices are discarded.

    Returns:
        list of ``(cls_id, poly)`` tuples; empty when the result has no masks.
    """
    if result.masks is None:
        return []

    boxes = result.boxes
    kept = []
    for idx, outline in enumerate(result.masks.xy):
        label = int(boxes.cls[idx])
        score = float(boxes.conf[idx])
        if score < CONF_THRES:
            continue

        reduced = simplify_polygon(outline, EPSILON_RATIO)
        # Fewer than three vertices cannot describe an area.
        if len(reduced) >= 3:
            kept.append((label, reduced))
    return kept
def save_yolo_seg_labels(masks, img_shape, save_path):
    """
    Write segmentation polygons to a YOLO-format label file.

    Each output line is ``<cls_id> x1 y1 x2 y2 ...``. Each x is divided by the
    image width and each y by the image height, clamped to [0, 1], and
    formatted with six decimals.

    Args:
        masks: iterable of ``(cls_id, poly)`` pairs, where ``poly`` yields
            ``(x, y)`` pixel coordinates.
        img_shape: image shape; only the first two entries (height, width)
            are used.
        save_path: destination ``.txt`` path; overwritten if it exists.
    """
    h, w = img_shape[:2]
    lines = []
    for cls_id, poly in masks:
        coords = []
        for x, y in poly:
            # Clamp so a vertex on or slightly past the border cannot make the
            # training-time label check reject the whole file.
            coords.append(f"{min(max(x / w, 0.0), 1.0):.6f}")
            coords.append(f"{min(max(y / h, 0.0), 1.0):.6f}")
        lines.append(" ".join([str(cls_id), *coords]))

    # Write the file even when there are no objects (an empty .txt):
    # YOLO training expects a label file for every image.
    with open(save_path, "w", encoding="utf-8") as f:
        if lines:
            # Terminate the last line as well, per the usual text-file convention.
            f.write("\n".join(lines) + "\n")
def run_folder_inference():
    """
    Run segmentation on every image in IMAGE_DIR and save YOLO label files.

    Labels go to ``OUT_DIR/labels/<image stem>.txt``. Files whose extension is
    not in IMG_EXTS are ignored, and images cv2 cannot decode are skipped with
    a message. Progress is printed to stdout.
    """
    label_dir = os.path.join(OUT_DIR, "labels")
    os.makedirs(label_dir, exist_ok=True)

    # Load the weights a single time and reuse them for every image.
    model = YOLO(MODEL_PATH)

    candidates = os.listdir(IMAGE_DIR)
    img_files = sorted(name for name in candidates if name.lower().endswith(IMG_EXTS))
    total = len(img_files)
    print(f"📂 共检测 {total} 张图片")

    for idx, img_name in enumerate(img_files, start=1):
        img = cv2.imread(os.path.join(IMAGE_DIR, img_name))
        if img is None:
            print(f"⚠️ 跳过无法读取: {img_name}")
            continue

        result = model(img, imgsz=IMG_SIZE, conf=CONF_THRES, verbose=False)[0]
        stem, _ext = os.path.splitext(img_name)
        save_yolo_seg_labels(
            extract_simplified_masks(result),
            img.shape,
            os.path.join(label_dir, f"{stem}.txt"),
        )
        print(f"[{idx}/{total}] ✅ {img_name}")

    print("🎉 标签生成完成(仅保存 YOLO Seg 标签)")
if __name__ == "__main__":
    # Script entry point: label every image in IMAGE_DIR with the configured model.
    run_folder_inference()