167 lines
4.1 KiB
Python
167 lines
4.1 KiB
Python
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import os
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
# ================= Configuration =================
# Weights file handed to YOLO() when the model is loaded.
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/60seg/exp3/weights/best.pt"
# Folder whose image files are run through the model.
IMAGE_DIR = "/media/hx/04e879fa-d697-4b02-ac7e-a4148876ebb0/dataset/1/分割60/class4/1"
# Root output folder; "labels" and "images" sub-folders are created inside it.
OUT_DIR = "./outputs"

# Passed to the model call as ``imgsz``.
IMG_SIZE = 640
# Passed to the model call as ``conf`` and re-applied when collecting masks.
CONF_THRES = 0.25

# Polygon simplification ratio (the key parameter controlling vertex count):
# the approximation tolerance is this fraction of the polygon perimeter.
EPSILON_RATIO = 0.001

# File-name suffixes (compared against the lower-cased name) treated as images.
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp")

# -------- Save switches --------
SAVE_LABELS = True  # whether to write YOLO segmentation label files
SAVE_VIS = True  # whether to write visualisation images
# =================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def simplify_polygon(poly, epsilon_ratio):
    """Simplify a closed polygon with ``cv2.approxPolyDP``.

    Args:
        poly: Array-like of (x, y) vertices. Coordinates are cast to
            ``int32`` before processing, so fractional parts are dropped.
        epsilon_ratio: Approximation tolerance expressed as a fraction of
            the closed polygon's perimeter; larger values keep fewer
            vertices.

    Returns:
        Array of shape ``(k, 2)`` with the simplified vertices. An input
        that contains no vertices yields an empty ``(0, 2)`` ``int32``
        array instead of being passed to OpenCV.
    """
    poly = np.asarray(poly).astype(np.int32)

    # A segmentation may come back without any contour points. OpenCV has
    # nothing to approximate in that case and the reshape below would run
    # on an unusable result, so return early; callers can reject the
    # polygon with their usual vertex-count check.
    if poly.size == 0:
        return np.empty((0, 2), dtype=np.int32)

    perimeter = cv2.arcLength(poly, True)
    epsilon = epsilon_ratio * perimeter
    approx = cv2.approxPolyDP(poly, epsilon, True)
    # approxPolyDP output is flattened to plain (x, y) rows.
    return approx.reshape(-1, 2)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_simplified_masks(result):
    """Collect the simplified polygon of every sufficiently confident instance.

    Args:
        result: A single prediction result (presumably an Ultralytics
            ``Results`` object -- confirm against the caller). It must
            expose ``masks`` (``None`` when nothing was segmented, otherwise
            carrying an ``xy`` sequence of polygons) and ``boxes`` with
            per-instance ``cls`` and ``conf``.

    Returns:
        List of ``(cls_id, poly)`` tuples, where ``poly`` is the output of
        ``simplify_polygon``. Instances scoring below ``CONF_THRES`` and
        polygons left with fewer than three vertices are omitted.
    """
    if result.masks is None:
        return []

    detections = result.boxes
    kept = []

    for index, contour in enumerate(result.masks.xy):
        label = int(detections.cls[index])
        score = float(detections.conf[index])

        if score < CONF_THRES:
            continue

        reduced = simplify_polygon(contour, EPSILON_RATIO)

        # Fewer than three vertices cannot describe an area.
        if len(reduced) >= 3:
            kept.append((label, reduced))

    return kept
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_yolo_seg_labels(masks, img_shape, save_path):
    """Write polygons to a text file, one ``cls x1 y1 x2 y2 ...`` row each.

    Args:
        masks: Iterable of ``(cls_id, poly)`` pairs; ``poly`` iterates over
            (x, y) vertices in pixel coordinates.
        img_shape: Image shape whose first two entries are height and width;
            x is divided by the width and y by the height.
        save_path: Destination file. It is always created, and is left
            empty when ``masks`` holds no entries.
    """
    height, width = img_shape[:2]

    rows = [
        str(cls_id) + " " + " ".join(
            f"{value:.6f}"
            for x, y in poly
            for value in (x / width, y / height)
        )
        for cls_id, poly in masks
    ]

    # Rows are newline-separated with no trailing newline; joining an empty
    # list writes nothing, which still leaves the (empty) file on disk.
    with open(save_path, "w") as f:
        f.write("\n".join(rows))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def draw_polygons(img, masks):
    """Draw segmentation polygons and their class ids on a copy of an image.

    Args:
        img: Image array; it is copied and never modified in place.
        masks: Iterable of ``(cls_id, poly)`` pairs, where ``poly`` is an
            array-like of (x, y) vertices.

    Returns:
        The annotated copy of ``img``. Polygons without any vertex are
        skipped, so the copy is returned unchanged for them.
    """
    vis = img.copy()
    colour = (0, 255, 0)

    for cls_id, poly in masks:
        poly = np.asarray(poly).astype(np.int32)

        # An empty polygon has no outline to draw and no first vertex to
        # anchor the class label on (indexing it would raise IndexError).
        if len(poly) == 0:
            continue

        cv2.polylines(
            vis,
            [poly],
            isClosed=True,
            color=colour,
            thickness=2,
        )

        # The class id is written at the polygon's first vertex.
        x, y = poly[0]
        cv2.putText(
            vis,
            str(cls_id),
            (int(x), int(y)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            colour,
            2,
        )

    return vis
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_folder_inference():
    """Run the segmentation model over every image file in ``IMAGE_DIR``.

    For each readable image the simplified polygons are extracted and,
    depending on ``SAVE_LABELS`` / ``SAVE_VIS``, written as a label file to
    ``OUT_DIR/labels`` and as an annotated image to ``OUT_DIR/images``.
    Unreadable images are reported and skipped. Progress is printed to
    stdout; nothing is returned.
    """
    # Output folders (only created for the outputs that are enabled).
    label_dir = os.path.join(OUT_DIR, "labels")
    vis_dir = os.path.join(OUT_DIR, "images")
    if SAVE_LABELS:
        os.makedirs(label_dir, exist_ok=True)
    if SAVE_VIS:
        os.makedirs(vis_dir, exist_ok=True)

    # Load the model once and reuse it for all images.
    model = YOLO(MODEL_PATH)

    img_files = sorted(
        name for name in os.listdir(IMAGE_DIR)
        if name.lower().endswith(IMG_EXTS)
    )
    total = len(img_files)

    print(f"📂 共检测 {total} 张图片")

    for idx, img_name in enumerate(img_files, start=1):
        img = cv2.imread(os.path.join(IMAGE_DIR, img_name))

        # cv2.imread signals failure by returning None.
        if img is None:
            print(f"⚠️ 跳过无法读取: {img_name}")
            continue

        predictions = model(
            img,
            imgsz=IMG_SIZE,
            conf=CONF_THRES,
            verbose=False,
        )
        masks = extract_simplified_masks(predictions[0])

        # ---------- save labels ----------
        if SAVE_LABELS:
            stem = os.path.splitext(img_name)[0]
            save_yolo_seg_labels(
                masks, img.shape, os.path.join(label_dir, stem + ".txt")
            )

        # ---------- save visualisation ----------
        if SAVE_VIS:
            annotated = draw_polygons(img, masks)
            cv2.imwrite(os.path.join(vis_dir, img_name), annotated)

        print(f"[{idx}/{total}] ✅ {img_name}")

    print("🎉 推理完成")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: start the batch inference only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    run_folder_inference()
|