97 lines
2.3 KiB
Python
97 lines
2.3 KiB
Python
import cv2
|
||
import numpy as np
|
||
import os
|
||
from ultralytics import YOLO
|
||
|
||
# ---------------- Configuration ----------------
MODEL_PATH: str = "seg.pt"   # weights file handed to YOLO(...)
IMAGE_PATH: str = "1.png"    # input image; the visualisation is saved next to it
IMG_SIZE: int = 640          # inference size, passed to the model as `imgsz`
CONF_THRES: float = 0.25     # confidence threshold (model call and mask filter)
ALPHA: float = 0.5           # weight of the filled-mask overlay when blending
# -----------------------------------------------
|
||
|
||
|
||
def get_color(idx):
    """Return a deterministic color tuple for class index *idx*.

    The same *idx* always maps to the same 3-tuple of Python ints in
    ``[0, 254]`` (``randint``'s upper bound is exclusive). This means a
    class keeps its color across images and runs.

    A private ``RandomState`` seeded with *idx* is used instead of
    ``np.random.seed`` so that the global NumPy RNG is not reset as a side
    effect. ``RandomState`` uses the same legacy seeding, so the colors
    are identical to those of the global-seed version.
    """
    rng = np.random.RandomState(idx)
    return tuple(int(x) for x in rng.randint(0, 255, 3))
|
||
|
||
|
||
def draw_segmentation(frame, result, conf_thres=None, alpha=None):
    """Fill each predicted mask polygon and blend it onto *frame*.

    Only the mask interiors are painted: no labels, contours or line
    segments are drawn.

    Args:
        frame: Image (NumPy array) the prediction was made on; it is not
            modified.
        result: A single Ultralytics result exposing ``masks.xy``,
            ``boxes.cls``, ``boxes.conf`` and ``names``.
        conf_thres: Detections whose confidence is below this value are
            skipped. Defaults to the module-level ``CONF_THRES``.
        alpha: Weight of the filled overlay in the blend; the original
            frame gets ``1 - alpha``. Defaults to the module-level ``ALPHA``.

    Returns:
        ``(image, widths)``. ``widths`` is a list of
        ``(class_name, max_x - min_x)`` tuples, one per drawn mask, giving
        the horizontal extent in pixels as a float. If ``result.masks`` is
        ``None``, *frame* itself is returned together with an empty list.
    """
    if conf_thres is None:
        conf_thres = CONF_THRES
    if alpha is None:
        alpha = ALPHA

    if result.masks is None:
        return frame, []

    boxes = result.boxes
    overlay = frame.copy()
    widths = []

    for i, poly in enumerate(result.masks.xy):
        # A mask with no extractable contour comes back as an empty
        # polygon. np.min/np.max would raise on it, and there is nothing
        # to fill.
        if len(poly) == 0:
            continue

        conf = float(boxes.conf[i])
        if conf < conf_thres:
            continue
        cls_id = int(boxes.cls[i])

        # Horizontal width = rightmost x - leftmost x. It is measured on the
        # float polygon so integer truncation cannot skew it by a pixel.
        xs = poly[:, 0]
        width = float(np.max(xs) - np.min(xs))
        widths.append((result.names[cls_id], width))

        # Fill the mask interior only; cv2.fillPoly needs int32 points.
        cv2.fillPoly(overlay, [poly.astype(np.int32)], get_color(cls_id))

    blended = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
    return blended, widths
|
||
|
||
|
||
def _read_image(path):
    """Read a color image from *path*, tolerating non-ASCII paths.

    ``cv2.imread`` cannot open paths with non-ASCII characters on Windows.
    To avoid that, the raw bytes are read with NumPy and decoded in memory.
    Like ``cv2.imread``, this returns ``None`` when the file is missing,
    empty or not decodable.
    """
    try:
        data = np.fromfile(path, dtype=np.uint8)
    except OSError:
        return None
    if data.size == 0:
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def _write_png(path, image):
    """Encode *image* as PNG and write it to *path* (non-ASCII safe).

    Raises:
        OSError: if encoding or writing fails. ``cv2.imwrite`` would only
            return ``False`` in that case.
    """
    ok, buf = cv2.imencode(".png", image)
    if not ok:
        raise OSError(f"无法保存图片: {path}")
    buf.tofile(path)


def run_image_inference(image_path=None, model_path=None):
    """Segment one image and save a mask-filled visualisation next to it.

    The function:
    - reads the image and runs the YOLO model on it;
    - fills the predicted masks via ``draw_segmentation``;
    - writes the result to ``<image path without extension>_seg.png``;
    - prints each drawn mask's horizontal width.

    Args:
        image_path: Input image. Defaults to the module-level ``IMAGE_PATH``.
        model_path: Weights passed to ``YOLO``. Defaults to ``MODEL_PATH``.

    Raises:
        FileNotFoundError: if the image cannot be read or decoded.
        OSError: if the visualisation cannot be written.
    """
    image_path = IMAGE_PATH if image_path is None else image_path
    model_path = MODEL_PATH if model_path is None else model_path

    # Read the image before loading the model so a bad path fails fast.
    img = _read_image(image_path)
    if img is None:
        raise FileNotFoundError(f"无法读取图片: {image_path}")

    model = YOLO(model_path)

    print(f"📷 正在处理: {image_path}")

    results = model(
        img,
        imgsz=IMG_SIZE,
        conf=CONF_THRES,
        verbose=False
    )
    result = results[0]

    # Visualisation with filled masks only.
    vis, widths = draw_segmentation(img, result)

    # Save in the input's directory: same path minus extension + "_seg.png".
    out_path = os.path.splitext(image_path)[0] + "_seg.png"
    _write_png(out_path, vis)

    # Report widths as text only (nothing is drawn for them).
    print("\nMask 水平宽度 (像素):")
    if widths:
        for name, width in widths:
            print(f" • {name}: {width:.1f}")
    else:
        print(" (无有效 mask)")

    print(f"\n完成!结果已保存至:\n {out_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the single-image pipeline only when executed as a script,
    # never on import.
    run_image_inference()