Files
zjsh_yolov11/onnx_val/onnx_seg_val.py
2025-09-05 14:29:33 +08:00

121 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# onnx_segmentation_demo.py
import cv2
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
def preprocess(image_path, input_size=(640, 640)):
    """Load an image and turn it into a normalized NCHW float32 tensor.

    Pipeline: read from disk, resize, BGR -> RGB, scale to [0, 1],
    HWC -> CHW, then add a leading batch axis.

    :param image_path: Path of the image file, read with ``cv2.imread``.
    :param input_size: Target size handed to ``cv2.resize``, i.e.
        ``(width, height)``.
    :return: ``(input_tensor, img_rgb)``. ``input_tensor`` is float32 with a
        leading batch axis of size 1 followed by channel, height and width
        axes. ``img_rgb`` is the *resized* RGB image (not the
        full-resolution file contents), kept for display.
    :raises FileNotFoundError: If the image cannot be read or decoded.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than inside cv2.resize.
        raise FileNotFoundError(
            f"Cannot read image (missing file or unsupported format): {image_path}"
        )

    img_resized = cv2.resize(img, input_size)
    img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)

    # Normalize [0, 255] -> [0, 1]. A model trained on [-1, 1] inputs would
    # need (x / 255 - 0.5) / 0.5 here instead.
    img_norm = img_rgb.astype(np.float32) / 255.0

    # HWC -> CHW -> 1xCxHxW
    img_chw = np.transpose(img_norm, (2, 0, 1))
    input_tensor = np.expand_dims(img_chw, axis=0)
    return input_tensor, img_rgb
def postprocess(output, original_shape, num_classes=1, threshold=0.5):
    """Convert a raw segmentation output into a uint8 mask at image size.

    :param output: Model output shaped ``(1, C, H, W)`` or ``(C, H, W)``.
        For binary segmentation a bare ``(H, W)`` map is accepted as well.
        Only the first batch item of a 4-D output is used.
    :param original_shape: Shape of the target image; only the first two
        entries ``(H, W)`` are read.
    :param num_classes: ``1`` selects binary segmentation (channel 0 is
        compared against ``threshold``); any other value selects multi-class
        segmentation (arg-max over the channel axis).
    :param threshold: Cut-off applied to channel 0 in the binary case. The
        values are compared as-is, so this assumes the model already emits
        probabilities — TODO confirm for the exported model.
    :return: ``(H, W)`` uint8 mask. Binary: values are 0 or 255.
        Multi-class: values are class indices.
    :raises ValueError: If ``output`` does not have a supported rank.
    """
    output = np.asarray(output)

    if output.ndim == 4:
        output = output[0]  # drop the batch axis
    if output.ndim == 2 and num_classes == 1:
        # A bare (H, W) probability map: add the channel axis so the indexing
        # below selects the whole map instead of its first row.
        output = output[np.newaxis]
    if output.ndim != 3:
        raise ValueError(
            "expected output shaped (1, C, H, W) or (C, H, W), "
            f"got array with shape {output.shape}"
        )

    target_h, target_w = original_shape[:2]

    if num_classes == 1:
        # Binary segmentation: threshold the first channel.
        mask = (output[0] > threshold).astype(np.uint8) * 255
    else:
        # Multi-class segmentation: pick the best-scoring class per pixel.
        mask = np.argmax(output, axis=0).astype(np.uint8)

    if mask.shape != (target_h, target_w):
        # Nearest-neighbour keeps mask values discrete (no blended labels).
        mask = cv2.resize(
            mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
        )
    return mask
def visualize_segmentation(image, mask, alpha=0.5, colormap=cv2.COLORMAP_JET):
    """Blend a color-mapped mask on top of an image.

    :param image: Image to draw on; passed unchanged to ``cv2.addWeighted``.
    :param mask: Mask to colorize. A 2-D mask whose maximum is at most 1 is
        first scaled by 255 and cast to uint8; any other mask is used as-is.
    :param alpha: Weight of the colored mask; the image gets ``1 - alpha``.
    :param colormap: OpenCV colormap id passed to ``cv2.applyColorMap``.
    :return: The weighted blend of ``image`` and the colorized mask.

    NOTE(review): OpenCV colormaps are produced in BGR channel order, so an
    RGB ``image`` will be blended with channel-swapped colors — confirm what
    callers pass.
    """
    is_unit_range = mask.ndim == 2 and mask.max() <= 1
    if is_unit_range:
        mask = (mask * 255).astype(np.uint8)

    heatmap = cv2.applyColorMap(mask, colormap)
    return cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
# ==================== Main program ====================
if __name__ == "__main__":
    ONNX_MODEL = "best.onnx"
    TEST_IMAGE = "1.jpg"
    INPUT_SIZE = (640, 640)
    NUM_CLASSES = 1  # 1: binary segmentation; >1: multi-class segmentation
    THRESHOLD = 0.8

    # 1. Load the ONNX model (providers are listed in priority order).
    print("📦 加载 ONNX 模型...")
    ort_session = ort.InferenceSession(
        ONNX_MODEL,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    # Name of the model's first input, needed to build the feed dict.
    input_name = ort_session.get_inputs()[0].name

    # 2. Preprocess. original_image is the resized RGB image returned by
    #    preprocess(), not the full-resolution file.
    input_tensor, original_image = preprocess(TEST_IMAGE, INPUT_SIZE)
    print(f"📥 输入形状: {input_tensor.shape}")

    # 3. Inference
    print("🚀 开始推理...")
    outputs = ort_session.run(None, {input_name: input_tensor})
    logits = outputs[0]  # assumes the first output is the segmentation map
    print(f"📤 输出形状: {logits.shape}")

    # 4. Postprocess
    mask = postprocess(
        logits,
        original_image.shape,
        num_classes=NUM_CLASSES,
        threshold=THRESHOLD,
    )

    # 5. Visualize. cv2.applyColorMap emits BGR colors while original_image
    #    is RGB, so blend in BGR and convert the result back to RGB;
    #    otherwise matplotlib shows the colormap with red and blue swapped.
    overlay_bgr = visualize_segmentation(
        cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR), mask, alpha=0.6
    )
    overlay = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)

    # 6. Display image, mask and overlay side by side.
    plt.figure(figsize=(12, 6))
    panels = (
        ("原始图像", original_image, None),
        ("预测 Mask", mask, "gray"),
        ("叠加效果", overlay, None),
    )
    for index, (title, picture, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, index)
        plt.title(title)
        if cmap is None:
            plt.imshow(picture)
        else:
            plt.imshow(picture, cmap=cmap)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

    # (Optional) save the result; overlay is RGB, so convert for OpenCV.
    # cv2.imwrite("overlay.jpg", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    print("✅ 验证完成!")