142 lines
3.1 KiB
Python
142 lines
3.1 KiB
Python
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import os
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
# ============================ Configuration ============================
# Weights file passed to YOLO() when the model is loaded.
MODEL_PATH = "60seg.pt"
# Image read with cv2.imread() and fed to the model.
IMAGE_PATH = "3.png"
# Directory that receives the visualization image and the label file.
OUT_DIR = "outputs"
# Value forwarded to the model call as ``imgsz``.
IMG_SIZE = 640
# Confidence threshold: forwarded to the model call as ``conf`` and also
# used to filter detections when drawing and when exporting labels.
CONF_THRES = 0.25
# Weight of the drawn overlay when it is blended with the original image.
ALPHA = 0.5
# =======================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_color(idx):
    """Return a deterministic pseudo-random color for an integer index.

    The color is drawn from a private ``RandomState`` seeded with ``idx``
    rather than by reseeding NumPy's global generator, so calling this
    function does not disturb random state used elsewhere in the process.
    The resulting values are the same ones the global-seed approach yields.

    Args:
        idx: Non-negative integer used as the seed (callers in this file
            pass a class id).

    Returns:
        A tuple of three Python ints, each in [0, 254]; the upper bound of
        ``randint(0, 255, 3)`` is exclusive.
    """
    rng = np.random.RandomState(idx)
    return tuple(int(channel) for channel in rng.randint(0, 255, 3))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def draw_segmentation(frame, result):
    """Draw filled segmentation polygons, outlines and labels on an image.

    Every instance whose confidence is at least ``CONF_THRES`` is painted
    onto a copy of ``frame`` (filled polygon, 2 px outline, and a
    "<class name> <confidence>" label near the polygon's first vertex).
    The painted copy is then alpha-blended with the original using
    ``ALPHA`` as the overlay weight.

    Args:
        frame: Image array to annotate; it is not modified.
        result: Inference result exposing ``masks`` (``None`` or an object
            with ``xy``, a sequence of (N, 2) point arrays), ``boxes``
            (with indexable ``cls`` and ``conf``) and ``names``
            (class id -> label). Presumably an Ultralytics ``Results``
            object - confirm against the caller.

    Returns:
        The blended image, or ``frame`` itself when ``result.masks`` is
        ``None``.
    """
    masks = result.masks
    if masks is None:
        # No masks at all: return the input untouched (no copy needed).
        return frame

    overlay = frame.copy()
    boxes = result.boxes
    names = result.names

    for i, poly in enumerate(masks.xy):
        cls_id = int(boxes.cls[i])
        conf = float(boxes.conf[i])
        if conf < CONF_THRES:
            continue

        poly = poly.astype(np.int32)
        if len(poly) == 0:
            # An empty polygon has no first vertex to anchor the label on
            # (poly[0] would raise IndexError) and nothing to fill.
            continue

        color = get_color(cls_id)

        # Filled mask plus a closed 2 px outline in the same color.
        cv2.fillPoly(overlay, [poly], color)
        cv2.polylines(overlay, [poly], True, color, 2)

        # Label anchored at the first vertex; plain ints keep OpenCV's
        # argument parsing independent of NumPy scalar types.
        x, y = (int(v) for v in poly[0])
        label = f"{names[cls_id]} {conf:.2f}"
        cv2.putText(
            overlay,
            label,
            (x, max(y - 5, 20)),  # keep the text baseline at y >= 20
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
        )

    return cv2.addWeighted(overlay, ALPHA, frame, 1 - ALPHA, 0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_masks_as_yolo_seg(result, img_shape, save_path, conf_thres=None):
    """Write segmentation polygons as a YOLO segmentation label file.

    Each kept instance becomes one line::

        class_id x1 y1 x2 y2 ...

    where x is divided by the image width and y by the image height, both
    formatted with six decimals. Instances below the confidence threshold
    are skipped, and so are polygons with fewer than three points (an
    empty polygon would otherwise produce a line holding only the class
    id, which is not a valid polygon entry).

    Args:
        result: Inference result exposing ``masks`` (``None`` or an object
            with ``xy``, a sequence of (N, 2) point arrays) and ``boxes``
            (with indexable ``cls`` and ``conf``). Presumably an
            Ultralytics ``Results`` object - confirm against the caller.
        img_shape: Image shape; only the first two entries
            (height, width) are used.
        save_path: Destination path of the ``.txt`` label file.
        conf_thres: Minimum confidence to keep an instance. ``None``
            (the default) falls back to the module-level ``CONF_THRES``.

    Returns:
        None. No file is written when ``result.masks`` is ``None`` or when
        no instance survives filtering.
    """
    if result.masks is None:
        return

    if conf_thres is None:
        conf_thres = CONF_THRES

    h, w = img_shape[:2]
    boxes = result.boxes

    lines = []
    for i, poly in enumerate(result.masks.xy):
        cls_id = int(boxes.cls[i])
        conf = float(boxes.conf[i])
        if conf < conf_thres:
            continue

        if len(poly) < 3:
            # Fewer than three vertices cannot describe a polygon.
            continue

        coords = []
        for x, y in poly:
            coords.extend((f"{x / w:.6f}", f"{y / h:.6f}"))
        lines.append(f"{cls_id} " + " ".join(coords))

    if lines:
        with open(save_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_image_inference():
    """Run segmentation on ``IMAGE_PATH`` and save the outputs to ``OUT_DIR``.

    Steps: create ``OUT_DIR`` if needed, load the model from
    ``MODEL_PATH``, read the image, run inference, then write

    * ``<image stem>_seg.png`` - the visualization from
      ``draw_segmentation``;
    * ``<image stem>.txt`` - labels from ``save_masks_as_yolo_seg``.

    Finally the two output paths are printed.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for
            ``IMAGE_PATH``.
        OSError: If ``cv2.imwrite`` reports that the visualization could
            not be written.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load the model.
    model = YOLO(MODEL_PATH)

    # Read the input image.
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        raise FileNotFoundError(f"❌ 无法读取图片: {IMAGE_PATH}")

    # Inference; only the first entry of the returned results is used.
    results = model(
        img,
        imgsz=IMG_SIZE,
        conf=CONF_THRES,
        verbose=False,
    )
    result = results[0]

    # Both output files share the input image's base name.
    stem = os.path.splitext(os.path.basename(IMAGE_PATH))[0]

    # 1) Save the visualization. imwrite signals failure through its
    #    return value, so check it instead of reporting a false success.
    vis = draw_segmentation(img, result)
    out_img_path = os.path.join(OUT_DIR, stem + "_seg.png")
    if not cv2.imwrite(out_img_path, vis):
        raise OSError(f"❌ 无法保存可视化结果: {out_img_path}")

    # 2) Save the YOLO segmentation labels.
    # NOTE(review): save_masks_as_yolo_seg writes nothing when no instance
    # is kept, yet the label path is still printed below.
    out_txt_path = os.path.join(OUT_DIR, stem + ".txt")
    save_masks_as_yolo_seg(result, img.shape, out_txt_path)

    print("✅ 推理完成")
    print(f"🖼 可视化结果: {out_img_path}")
    print(f"📄 YOLO Seg 标注: {out_txt_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: run the single-image inference pipeline only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_image_inference()
|