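# NOTE: the lines below were captured from a web repository viewer header;
# they are kept as comments so this file remains valid Python.
# Files
# zjsh_yolov11/推理图片反向上传CVAT/seg/yolo_seg_infer_vis—f.py
#
# 126 lines
# 3.0 KiB
# Python
# Raw Permalink Normal View History
#
# 2026-03-10 13:58:21 +08:00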
import cv2
import numpy as np
import os
from ultralytics import YOLO
# ---------------- Configuration ----------------
# Trained YOLO segmentation weights loaded by run_folder_inference().
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp3/weights/best.pt"
# Input image folder. Listed non-recursively; only files whose extension is in
# IMG_EXTS are processed.
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/6111c/1"
# Output root; only label (.txt) files are saved, no images.
# NOTE(review): run_folder_inference() joins this with another "labels"
# subdirectory, so files end up in ./labels/labels/ — confirm that is intended.
OUT_DIR = "./labels"
# Inference size passed to the model as `imgsz`.
IMG_SIZE = 640
# Confidence threshold: passed to the model call and re-checked per instance
# when masks are extracted.
CONF_THRES = 0.25
# Polygon simplification ratio (the main knob controlling point count):
# approxPolyDP epsilon = EPSILON_RATIO * contour perimeter. Larger -> fewer points.
EPSILON_RATIO = 0.01
# Image extensions to process (compared against the lower-cased file name).
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp")
# -------------------------------------
def simplify_polygon(poly, epsilon_ratio):
    """
    Simplify a closed polygon with cv2.approxPolyDP (Douglas-Peucker).

    Args:
        poly: Array-like of (x, y) vertices, shape (N, 2), in pixel coordinates.
        epsilon_ratio: Maximum allowed deviation between the contour and its
            approximation, expressed as a fraction of the closed contour's
            perimeter (epsilon = epsilon_ratio * perimeter).

    Returns:
        np.ndarray of shape (M, 2), dtype float32. Inputs with fewer than 3
        vertices cannot form a polygon and are returned unchanged (as an
        (N, 2) float32 array) without calling OpenCV.
    """
    # Keep sub-pixel precision: OpenCV accepts float32 point sets, whereas an
    # int32 cast truncates every coordinate toward zero before simplifying.
    # ascontiguousarray guarantees a memory layout the cv2 bindings accept.
    pts = np.ascontiguousarray(poly, dtype=np.float32).reshape(-1, 2)
    if len(pts) < 3:
        # Empty or degenerate contour: nothing to simplify.
        return pts
    perimeter = cv2.arcLength(pts, True)
    epsilon = epsilon_ratio * perimeter
    approx = cv2.approxPolyDP(pts, epsilon, True)
    # approxPolyDP returns (M, 1, 2); flatten to (M, 2) vertex pairs.
    return approx.reshape(-1, 2)
def extract_simplified_masks(result, conf_thres=None, epsilon_ratio=None):
    """
    Extract and simplify the instance polygons of one YOLO segmentation result.

    Args:
        result: A single ultralytics result (e.g. ``model(img)[0]``). Uses
            ``result.masks.xy`` (one polygon per instance, pixel coordinates)
            and the parallel ``result.boxes.cls`` / ``result.boxes.conf``.
        conf_thres: Minimum confidence for an instance to be kept. Defaults to
            the module-level CONF_THRES.
        epsilon_ratio: Simplification ratio forwarded to simplify_polygon().
            Defaults to the module-level EPSILON_RATIO.

    Returns:
        list of (cls_id: int, poly: np.ndarray of shape (M, 2)) with M >= 3;
        an empty list when the result has no masks.
    """
    simplified = []
    if result.masks is None:
        return simplified

    # Resolve defaults lazily so the module constants are read at call time.
    if conf_thres is None:
        conf_thres = CONF_THRES
    if epsilon_ratio is None:
        epsilon_ratio = EPSILON_RATIO

    # Convert class ids and confidences to plain Python lists once, instead of
    # indexing the tensors and converting element by element.
    cls_ids = result.boxes.cls.tolist()
    confs = result.boxes.conf.tolist()

    for poly, cls_id, conf in zip(result.masks.xy, cls_ids, confs):
        if conf < conf_thres:
            continue
        poly = simplify_polygon(poly, epsilon_ratio)
        # Fewer than 3 vertices is not a valid polygon label.
        if len(poly) < 3:
            continue
        simplified.append((int(cls_id), poly))
    return simplified
def save_yolo_seg_labels(masks, img_shape, save_path):
    """
    Write instance polygons to a YOLO segmentation label file.

    Each line reads ``<cls_id> x1 y1 x2 y2 ...``. Every x is divided by the
    image width and every y by the image height, and each value is formatted
    with 6 decimal places. Lines are newline-separated with no trailing newline.

    Args:
        masks: Iterable of (cls_id, poly) pairs; poly is a sequence of (x, y)
            points in pixel coordinates.
        img_shape: Image shape whose first two entries are (height, width).
        save_path: Destination .txt path; overwritten if it already exists.
    """
    height, width = img_shape[:2]

    def _format_line(cls_id, poly):
        coords = " ".join(f"{px / width:.6f} {py / height:.6f}" for px, py in poly)
        return f"{cls_id} {coords}"

    content = "\n".join(_format_line(cls_id, poly) for cls_id, poly in masks)

    # Always create the file, even with no instances: an empty .txt is still
    # written for every image (needed for YOLO training).
    with open(save_path, "w") as fh:
        fh.write(content)
def run_folder_inference():
    """
    Segment every image in IMAGE_DIR and write one YOLO-seg label file each.

    Reads the module configuration (MODEL_PATH, IMAGE_DIR, OUT_DIR, IMG_SIZE,
    CONF_THRES, IMG_EXTS). Labels go to ``OUT_DIR/labels/<stem>.txt``.
    Images that cv2.imread cannot decode are reported and skipped. Progress
    is printed to stdout.
    """
    # NOTE(review): OUT_DIR is "./labels", so this resolves to ./labels/labels;
    # confirm the extra level is intended.
    label_dir = os.path.join(OUT_DIR, "labels")
    os.makedirs(label_dir, exist_ok=True)

    # Load the model a single time and reuse it for every image.
    model = YOLO(MODEL_PATH)

    image_names = sorted(
        name for name in os.listdir(IMAGE_DIR) if name.lower().endswith(IMG_EXTS)
    )
    total = len(image_names)
    print(f"📂 共检测 {total} 张图片")

    for position, image_name in enumerate(image_names, start=1):
        image = cv2.imread(os.path.join(IMAGE_DIR, image_name))
        if image is None:
            print(f"⚠️ 跳过无法读取: {image_name}")
            continue

        prediction = model(image, imgsz=IMG_SIZE, conf=CONF_THRES, verbose=False)[0]
        instances = extract_simplified_masks(prediction)

        stem, _ext = os.path.splitext(image_name)
        save_yolo_seg_labels(instances, image.shape, os.path.join(label_dir, f"{stem}.txt"))
        print(f"[{position}/{total}] ✅ {image_name}")

    print("🎉 标签生成完成(仅保存 YOLO Seg 标签)")
if __name__ == "__main__":
    # Run batch label generation only when executed as a script, not on import.
    run_folder_inference()