121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
|
|
# onnx_segmentation_demo.py
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import onnxruntime as ort
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
|
|||
|
|
|
|||
|
|
def preprocess(image_path, input_size=(640, 640), mean=None, std=None):
    """Load an image and turn it into a 1x3xHxW float32 tensor for ONNX inference.

    Steps: read with OpenCV (BGR) -> resize -> BGR to RGB -> scale to [0, 1]
    -> optional (x - mean) / std -> HWC to CHW -> add a batch dimension.

    :param image_path: path of the image file, read with cv2.imread.
    :param input_size: target size passed to cv2.resize, i.e. (width, height).
    :param mean: optional per-channel mean (RGB order, in [0, 1] units)
        subtracted after scaling, e.g. (0.485, 0.456, 0.406) for ImageNet
        statistics or 0.5 for models expecting [-1, 1]. None keeps [0, 1].
    :param std: optional per-channel std divided out after the mean step,
        e.g. (0.229, 0.224, 0.225) for ImageNet or 0.5 for [-1, 1].
    :return: (input_tensor, img_rgb). input_tensor is a C-contiguous float32
        array of shape (1, 3, height, width). img_rgb is the *resized* uint8
        RGB image (height, width, 3) for display; it is not the
        original-resolution image.
    :raises FileNotFoundError: if cv2.imread cannot read image_path.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(
            f"cv2.imread could not read {image_path!r} (missing file or unsupported format)"
        )

    img_resized = cv2.resize(img, input_size)
    img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)

    # Scale [0, 255] -> [0, 1], then apply optional mean/std normalization
    # (broadcast over the last, channel axis of the HWC image).
    img_norm = img_rgb.astype(np.float32) / 255.0
    if mean is not None:
        img_norm = img_norm - np.asarray(mean, dtype=np.float32)
    if std is not None:
        img_norm = img_norm / np.asarray(std, dtype=np.float32)

    # HWC -> CHW -> 1xCxHxW. The transpose yields a strided view, so make a
    # contiguous copy before handing it to onnxruntime.
    img_chw = np.transpose(img_norm, (2, 0, 1))
    input_tensor = np.ascontiguousarray(img_chw[np.newaxis, ...])

    return input_tensor, img_rgb
|
|||
|
|
|
|||
|
|
|
|||
|
|
def postprocess(output, original_shape, num_classes=1, threshold=0.5, apply_sigmoid=False):
    """Convert a raw ONNX segmentation output into a mask at the target image size.

    :param output: model output with shape (1, C, H, W), (C, H, W) or (H, W).
        For 4-D input only the first batch item is used.
    :param original_shape: shape of the target image, (H, W) or (H, W, C);
        only the first two entries are used.
    :param num_classes: 1 for binary segmentation (channel 0 is thresholded);
        >1 for multi-class segmentation (argmax over the channel axis of a
        (C, H, W) output).
    :param threshold: binarization threshold for the binary case, compared
        against the channel-0 values (after the sigmoid if apply_sigmoid=True).
    :param apply_sigmoid: set True when the binary model emits raw logits
        instead of probabilities. Ignored in the multi-class case.
    :return: uint8 mask of shape (H, W) taken from original_shape. Binary case:
        values 0/255. Multi-class case: class indices (uint8, so at most 256
        classes are representable).
    """
    output = np.asarray(output)
    if output.ndim == 4:
        output = output[0]  # drop the batch dimension (first item only)

    h, w = original_shape[:2]

    if num_classes == 1:
        # A 2-D output already is the (H, W) map; otherwise take channel 0.
        prob_map = output if output.ndim == 2 else output[0]
        if apply_sigmoid:
            # Clip so exp() cannot overflow for extreme logits.
            prob_map = 1.0 / (1.0 + np.exp(-np.clip(prob_map, -88.0, 88.0)))
        mask = (prob_map > threshold).astype(np.uint8) * 255
    else:
        mask = np.argmax(output, axis=0).astype(np.uint8)

    if mask.shape != (h, w):
        # Nearest-neighbour keeps mask values discrete (no blended labels).
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return mask
|
|||
|
|
|
|||
|
|
|
|||
|
|
def visualize_segmentation(image, mask, alpha=0.5, colormap=cv2.COLORMAP_JET, rgb=True, num_classes=None):
    """Blend a colour-mapped segmentation mask onto an image.

    :param image: uint8 3-channel image with the same height/width as mask.
        Its channel order must match `rgb` (preprocess() in this file returns RGB).
    :param mask: 2-D mask. If all values are <= 1 (0/1 or bool) it is scaled
        to 0/255. A non-uint8 mask is clipped to [0, 255] and cast to uint8,
        because cv2.applyColorMap only accepts 8-bit input.
    :param alpha: weight of the coloured mask; the image gets weight 1 - alpha.
    :param colormap: OpenCV colormap id passed to cv2.applyColorMap.
    :param rgb: True (default) if image is RGB. cv2.applyColorMap produces BGR
        colours, so they are converted to RGB before blending; otherwise the
        mask colours appear with red/blue swapped. Pass False for BGR images.
    :param num_classes: optional class count for a class-index mask (values
        0..num_classes-1, as produced by postprocess in multi-class mode).
        When > 1 the indices are spread over 0..255 so each class gets a
        distinct colour instead of all landing at the low end of the colormap.
    :return: blended uint8 image in the same channel order as image.
    """
    mask = np.asarray(mask)
    if num_classes is not None and num_classes > 1:
        mask = mask.astype(np.float32) * (255.0 / (num_classes - 1))
    elif mask.ndim == 2 and mask.max() <= 1:
        mask = mask * 255
    if mask.dtype != np.uint8:
        mask = np.clip(mask, 0, 255).astype(np.uint8)

    # Colourize the mask (OpenCV returns BGR) and match the image's channel order.
    colored_mask = cv2.applyColorMap(mask, colormap)
    if rgb:
        colored_mask = cv2.cvtColor(colored_mask, cv2.COLOR_BGR2RGB)

    # Weighted blend: image * (1 - alpha) + colored_mask * alpha.
    return cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ==================== Main program ====================

ONNX_MODEL = "best.onnx"
TEST_IMAGE = "1.jpg"
INPUT_SIZE = (640, 640)  # (width, height); fallback when the model's input H/W are not fixed
NUM_CLASSES = 1  # 1: binary segmentation; >1: multi-class segmentation
THRESHOLD = 0.8


def _model_input_size(session, default):
    """Return (width, height) from the session's first input if it is a static 4-D shape.

    Falls back to `default` when the input is not 4-D or its last two
    dimensions are symbolic (onnxruntime reports those as strings or None).
    """
    shape = session.get_inputs()[0].shape
    if len(shape) == 4 and all(isinstance(d, int) and d > 0 for d in shape[2:]):
        # Assumes NCHW layout, matching the tensor that preprocess() builds.
        return shape[3], shape[2]  # cv2.resize expects (width, height)
    return default


def main():
    """Run the demo: load the model, segment TEST_IMAGE and plot the results."""
    # 1. Load the ONNX model. Request CUDA only if this onnxruntime build
    #    provides it, to avoid a provider warning on CPU-only installs.
    print("📦 加载 ONNX 模型...")
    available = ort.get_available_providers()
    providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
    ort_session = ort.InferenceSession(ONNX_MODEL, providers=providers)

    input_name = ort_session.get_inputs()[0].name
    input_size = _model_input_size(ort_session, INPUT_SIZE)

    # 2. Preprocess
    input_tensor, original_image = preprocess(TEST_IMAGE, input_size)
    print(f"📥 输入形状: {input_tensor.shape}")

    # 3. Inference
    print("🚀 开始推理...")
    outputs = ort_session.run(None, {input_name: input_tensor})
    raw_output = outputs[0]  # assumes the first output is the segmentation map
    print(f"📤 输出形状: {raw_output.shape}")

    # 4. Postprocess. In the binary case postprocess() thresholds channel 0
    #    directly, so this assumes the exported model already ends in a
    #    sigmoid. TODO: confirm for best.onnx.
    mask = postprocess(raw_output, original_image.shape, num_classes=NUM_CLASSES, threshold=THRESHOLD)

    # 5. Overlay the mask on the (resized, RGB) image
    overlay = visualize_segmentation(original_image, mask, alpha=0.6)

    # 6. Display. The titles are Chinese and matplotlib's default font has no
    #    CJK glyphs, so common CJK fonts are listed first; matplotlib uses the
    #    first one that is installed.
    plt.rcParams["font.sans-serif"] = [
        "SimHei",
        "Microsoft YaHei",
        "Noto Sans CJK SC",
        "WenQuanYi Micro Hei",
    ] + list(plt.rcParams["font.sans-serif"])

    panels = (
        ("原始图像", original_image, None),
        ("预测 Mask", mask, "gray"),
        ("叠加效果", overlay, None),
    )
    plt.figure(figsize=(12, 6))
    for idx, (title, img, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.title(title)
        plt.imshow(img, cmap=cmap)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

    # Optional: save the overlay (convert RGB back to BGR for OpenCV).
    # cv2.imwrite("overlay.jpg", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    print("✅ 验证完成!")


if __name__ == "__main__":
    main()
|