Files
zjsh_yolov11/yolo11_seg/yolo_seg_infer_vis61.py
琉璃月光 67883f1a50 最新推送
2026-03-10 14:14:14 +08:00

142 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from ultralytics import YOLO
# ---------------- Configuration ----------------
# Trained segmentation weights loaded via YOLO(MODEL_PATH).
MODEL_PATH = "/home/hx/yolo/ultralytics_yolo11-main/runs/train/61seg/exp/weights/best.pt"
# Input image read by run_image_inference() (relative paths resolve against the CWD).
IMAGE_PATH = "2.png"
# Directory receiving "<stem>_seg.png" (visualization) and "<stem>.txt" (YOLO-seg labels).
OUT_DIR = "outputs"
# Inference size passed to the model as `imgsz`.
IMG_SIZE = 640
# Confidence threshold: passed to the model as `conf`, and re-checked when
# drawing and when saving labels.
CONF_THRES = 0.25
# Weight of the drawn overlay in cv2.addWeighted; the original frame gets 1 - ALPHA.
ALPHA = 0.5
# -----------------------------------------------
def get_color(idx):
    """Return a deterministic color for a class index.

    The same ``idx`` always yields the same color, so all instances of one
    class are drawn alike. A private ``RandomState`` seeded with ``idx`` is
    used instead of ``np.random.seed(idx)``. Both draw from the same legacy
    MT19937 stream, so the colors are unchanged. The difference is that
    NumPy's global RNG is no longer reset as a hidden side effect of drawing.

    Args:
        idx: Integer seed, typically the class id.

    Returns:
        Tuple of three Python ints in [0, 254] (``randint``'s upper bound is
        exclusive), usable directly as an OpenCV color.
    """
    rng = np.random.RandomState(idx)
    return tuple(int(x) for x in rng.randint(0, 255, 3))
def draw_segmentation(frame, result):
    """Overlay instance-segmentation results on an image.

    Keeps the original visualization logic. For every instance whose
    confidence is at least ``CONF_THRES``, it draws onto a copy of ``frame``
    in the class color from ``get_color``:

    - a filled polygon,
    - its outline,
    - a "<class name> <conf>" label.

    The copy is then alpha-blended with ``frame`` using ``ALPHA``, so outlines
    and labels are blended as well.

    Args:
        frame: Image array the result was computed on; it is not modified.
        result: A single Ultralytics result. Reads ``masks.xy``,
            ``boxes.cls``, ``boxes.conf`` and ``names``.

    Returns:
        The blended image, or ``frame`` itself when ``result.masks`` is None.
    """
    if result.masks is None:
        return frame
    overlay = frame.copy()
    names = result.names
    for poly, cls_val, conf_val in zip(result.masks.xy, result.boxes.cls, result.boxes.conf):
        cls_id = int(cls_val)
        conf = float(conf_val)
        if conf < CONF_THRES:
            continue
        # masks.xy may contain empty polygons (presumably when a mask yields
        # no contour). There is nothing to draw, and poly[0] below would
        # raise IndexError.
        if len(poly) == 0:
            continue
        color = get_color(cls_id)
        poly = poly.astype(np.int32)
        # Filled mask
        cv2.fillPoly(overlay, [poly], color)
        # Closed contour
        cv2.polylines(overlay, [poly], True, color, 2)
        # Label anchored at the first polygon vertex. It is kept at least
        # 20 px from the top so the text is not clipped. Cast to plain ints
        # because OpenCV's point parsing can reject NumPy scalars.
        x, y = int(poly[0][0]), int(poly[0][1])
        label = f"{names[cls_id]} {conf:.2f}"
        cv2.putText(
            overlay,
            label,
            (x, max(y - 5, 20)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
        )
    return cv2.addWeighted(overlay, ALPHA, frame, 1 - ALPHA, 0)
def save_masks_as_yolo_seg(result, img_shape, save_path):
    """Write segmentation polygons to a YOLO-seg label file.

    Each kept instance becomes one line::

        class_id x1 y1 x2 y2 ...

    x is divided by the image width and y by the image height, and each value
    is written with 6 decimals.

    Instances are skipped when:

    - their confidence is below ``CONF_THRES``, or
    - their polygon has fewer than 3 points. An empty polygon would otherwise
      produce a bare "class_id" line, which is not a valid polygon label and
      can cause the whole label file to be rejected when training.

    Args:
        result: A single Ultralytics result. Reads ``masks.xy`` (pixel
            coordinates), ``boxes.cls`` and ``boxes.conf``.
        img_shape: Shape of the source image. Only ``[:2]`` = (height, width)
            is used.
        save_path: Destination .txt path.

    Nothing is written, and an existing file at ``save_path`` is left
    untouched, when there are no masks or no instance survives the filters.
    """
    if result.masks is None:
        return
    h, w = img_shape[:2]
    boxes = result.boxes
    lines = []
    for poly, cls_val, conf_val in zip(result.masks.xy, boxes.cls, boxes.conf):
        if float(conf_val) < CONF_THRES:
            continue
        # A polygon needs at least three vertices.
        if len(poly) < 3:
            continue
        coords = " ".join(f"{x / w:.6f} {y / h:.6f}" for x, y in poly)
        lines.append(f"{int(cls_val)} {coords}")
    if lines:
        with open(save_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
def run_image_inference():
    """Run the segmentation model on IMAGE_PATH and save the outputs to OUT_DIR.

    Writes two files, where <stem> is the IMAGE_PATH file name without its
    extension:

    - "<stem>_seg.png": the blended visualization from draw_segmentation.
    - "<stem>.txt": YOLO-seg labels, via save_masks_as_yolo_seg.

    Raises:
        FileNotFoundError: If cv2.imread cannot read IMAGE_PATH.
        OSError: If cv2.imwrite fails to write the visualization.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    # Read the image before loading the model so a bad path fails fast,
    # without paying for model/weights initialization.
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        raise FileNotFoundError(f"❌ 无法读取图片: {IMAGE_PATH}")
    model = YOLO(MODEL_PATH)
    # Single image in, so take the first (only) result.
    results = model(
        img,
        imgsz=IMG_SIZE,
        conf=CONF_THRES,
        verbose=False,
    )
    result = results[0]

    stem = os.path.splitext(os.path.basename(IMAGE_PATH))[0]
    out_img_path = os.path.join(OUT_DIR, stem + "_seg.png")
    out_txt_path = os.path.join(OUT_DIR, stem + ".txt")

    # 1) Save the visualization. cv2.imwrite signals failure through its
    # return value rather than an exception, so check it.
    vis = draw_segmentation(img, result)
    if not cv2.imwrite(out_img_path, vis):
        raise OSError(f"❌ 无法保存可视化结果: {out_img_path}")

    # 2) Save the YOLO-seg labels. No file is written when nothing is detected.
    save_masks_as_yolo_seg(result, img.shape, out_txt_path)
    print("✅ 推理完成")
    print(f"🖼 可视化结果: {out_img_path}")
    print(f"📄 YOLO Seg 标注: {out_txt_path}")
if __name__ == "__main__":
    # Script entry point: run inference once using the module-level configuration.
    run_image_inference()