# Listing metadata: 142 lines, 3.1 KiB, Python
import cv2
|
||
import numpy as np
|
||
import os
|
||
from ultralytics import YOLO
|
||
|
||
# ---------------- Configuration ----------------
# Weights file handed to YOLO() in run_image_inference.
MODEL_PATH = "60seg.pt"
# Input image loaded with cv2.imread.
IMAGE_PATH = "3.png"
# Directory that receives the visualisation and the label file.
OUT_DIR = "outputs"
# Value forwarded as `imgsz` to the model call.
IMG_SIZE = 640
# Instances with a confidence below this value are ignored.
CONF_THRES = 0.25
# Weight of the drawn overlay when blended with the source image.
ALPHA = 0.5
# -----------------------------------------------
|
||
|
||
|
||
def get_color(idx):
    """Return a deterministic 3-channel color tuple for the given index.

    The same ``idx`` always maps to the same color. A private
    ``RandomState`` seeded with ``idx`` is used instead of reseeding NumPy's
    global generator, so calling this function no longer disturbs any other
    code that relies on ``np.random``. The values produced are identical to
    those obtained from ``np.random.seed(idx)`` followed by
    ``np.random.randint(0, 255, 3)``.

    Args:
        idx: Non-negative integer seed (the callers in this file pass a
            class id).

    Returns:
        A tuple of three plain Python ints, each in ``[0, 254]`` (the upper
        bound of ``randint`` is exclusive).
    """
    rng = np.random.RandomState(idx)
    return tuple(int(channel) for channel in rng.randint(0, 255, 3))
|
||
|
||
|
||
def draw_segmentation(frame, result, conf_thres=None, alpha=None):
    """Draw filled masks, outlines and labels, blended onto ``frame``.

    Every polygon in ``result.masks.xy`` whose confidence is at least
    ``conf_thres`` is filled and outlined on a copy of ``frame`` in the color
    returned by ``get_color`` for its class id, and a
    ``"<class name> <confidence>"`` label is written near its first vertex.
    The copy is then blended with the untouched ``frame`` via
    ``cv2.addWeighted``; because the labels are drawn on the copy, they are
    blended as well.

    Args:
        frame: Image array to annotate; it is not modified.
        result: Prediction object exposing ``masks`` (with ``.xy``),
            ``boxes`` (with ``.cls`` and ``.conf``) and ``names``, as passed
            by ``run_image_inference``.
        conf_thres: Minimum confidence to draw an instance. Defaults to the
            module-level ``CONF_THRES``.
        alpha: Blend weight of the drawn overlay. Defaults to the
            module-level ``ALPHA``.

    Returns:
        ``frame`` itself (not a copy) when ``result.masks`` is ``None``;
        otherwise the blended image returned by ``cv2.addWeighted``.
    """
    if result.masks is None:
        return frame

    if conf_thres is None:
        conf_thres = CONF_THRES
    if alpha is None:
        alpha = ALPHA

    # Copy only once we know there is something to draw.
    overlay = frame.copy()
    boxes = result.boxes
    names = result.names

    for i, poly in enumerate(result.masks.xy):
        conf = float(boxes.conf[i])
        if conf < conf_thres:
            continue

        # An empty polygon has no first vertex: poly[0] below would raise
        # IndexError, so there is nothing to draw or to anchor a label to.
        if len(poly) == 0:
            continue

        cls_id = int(boxes.cls[i])
        color = get_color(cls_id)
        poly = poly.astype(np.int32)

        # Filled mask.
        cv2.fillPoly(overlay, [poly], color)

        # Closed outline.
        cv2.polylines(overlay, [poly], True, color, 2)

        # Label anchored at the first vertex, kept at least 20 px from the
        # top edge; coordinates are cast to plain ints for OpenCV.
        x, y = (int(v) for v in poly[0])
        label = f"{names[cls_id]} {conf:.2f}"
        cv2.putText(
            overlay,
            label,
            (x, max(y - 5, 20)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
        )

    return cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
|
||
|
||
|
||
def save_masks_as_yolo_seg(result, img_shape, save_path, conf_thres=None):
    """Write the kept polygons to a YOLO segmentation label file (txt).

    One line is written per kept instance, in the form::

        class_id x1 y1 x2 y2 ...

    where every x is divided by the image width and every y by the image
    height, formatted with six decimals. Lines are joined with ``"\\n"`` and
    no trailing newline is written.

    An instance is skipped when its confidence is below ``conf_thres`` or
    when its polygon has fewer than three vertices (such a line would not
    describe a polygon; an empty one would contain only the class id).

    Nothing is written, and no file is created, when ``result.masks`` is
    ``None`` or when no instance is kept.

    Args:
        result: Prediction object exposing ``masks`` (with ``.xy``) and
            ``boxes`` (with ``.cls`` and ``.conf``).
        img_shape: Shape whose first two entries are ``(height, width)``.
        save_path: Destination path of the label file.
        conf_thres: Minimum confidence to export an instance. Defaults to
            the module-level ``CONF_THRES``.
    """
    if result.masks is None:
        return

    if conf_thres is None:
        conf_thres = CONF_THRES

    h, w = img_shape[:2]
    boxes = result.boxes

    lines = []
    for i, poly in enumerate(result.masks.xy):
        if float(boxes.conf[i]) < conf_thres:
            continue

        # Fewer than 3 vertices cannot form a polygon; skip instead of
        # emitting a malformed label line.
        if len(poly) < 3:
            continue

        cls_id = int(boxes.cls[i])
        coords = " ".join(f"{x / w:.6f} {y / h:.6f}" for x, y in poly)
        lines.append(f"{cls_id} {coords}")

    if lines:
        with open(save_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
|
||
|
||
|
||
def run_image_inference():
    """Run segmentation on ``IMAGE_PATH`` and save the outputs to ``OUT_DIR``.

    Steps:
      1. Create ``OUT_DIR`` if needed and load the model from ``MODEL_PATH``.
      2. Read the image and run the model with ``IMG_SIZE`` / ``CONF_THRES``.
      3. Save the visualisation as ``<image stem>_seg.png``.
      4. Save the YOLO segmentation labels as ``<image stem>.txt``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``IMAGE_PATH``.
        OSError: If ``cv2.imwrite`` reports that the visualisation could not
            be written (previously this failure was silently ignored and a
            success message was still printed).
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load the model.
    model = YOLO(MODEL_PATH)

    # Read the image; cv2.imread returns None instead of raising on failure.
    img = cv2.imread(IMAGE_PATH)
    if img is None:
        raise FileNotFoundError(f"❌ 无法读取图片: {IMAGE_PATH}")

    # Inference; only the first result is used.
    results = model(
        img,
        imgsz=IMG_SIZE,
        conf=CONF_THRES,
        verbose=False,
    )
    result = results[0]

    # Both output files share the input image's base name.
    stem = os.path.splitext(os.path.basename(IMAGE_PATH))[0]

    # 1) Save the visualisation.
    vis = draw_segmentation(img, result)
    out_img_path = os.path.join(OUT_DIR, stem + "_seg.png")
    if not cv2.imwrite(out_img_path, vis):
        raise OSError(f"cv2.imwrite failed for: {out_img_path}")

    # 2) Save the YOLO segmentation labels.
    # NOTE: save_masks_as_yolo_seg writes no file when nothing is kept, so
    # the path printed below may not exist after this call.
    out_txt_path = os.path.join(OUT_DIR, stem + ".txt")
    save_masks_as_yolo_seg(result, img.shape, out_txt_path)

    print("✅ 推理完成")
    print(f"🖼 可视化结果: {out_img_path}")
    print(f"📄 YOLO Seg 标注: {out_txt_path}")
|
||
|
||
|
||
# Script entry point: run the single-image pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run_image_inference()
|