Files
zjsh_yolov11/onnx_val/onnx_seg_val.py

121 lines
3.7 KiB
Python
Raw Permalink Normal View History

2025-09-05 14:29:33 +08:00
# onnx_segmentation_demo.py
import cv2
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
def preprocess(image_path, input_size=(640, 640)):
    """Load an image and turn it into a normalized NCHW float32 tensor.

    Pipeline: read with OpenCV, resize to ``input_size``, convert
    BGR -> RGB, scale pixel values from [0, 255] to [0, 1], reorder
    HWC -> CHW and prepend a batch axis.

    :param image_path: path of the image file (``str`` or ``os.PathLike``).
    :param input_size: target size passed to ``cv2.resize`` as ``dsize``,
        i.e. ``(width, height)``.
    :return: ``(input_tensor, img_rgb)``. ``input_tensor`` is a float32
        array of shape ``(1, C, height, width)``; ``img_rgb`` is the
        *resized* RGB image (not the full-resolution original), meant for
        display.
    :raises FileNotFoundError: if ``image_path`` is not an existing file.
    :raises ValueError: if the file exists but OpenCV cannot decode it.
    """
    import os  # local import so this block does not rely on file-level imports

    path = os.fspath(image_path)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path!r}")

    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error would surface later inside cv2.resize.
        raise ValueError(f"OpenCV could not decode image: {path!r}")

    img_resized = cv2.resize(img, input_size)
    img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)

    # Normalize [0, 255] -> [0, 1]. If the model expects inputs in [-1, 1],
    # use (x / 255.0 - 0.5) / 0.5 instead.
    img_norm = img_rgb.astype(np.float32) / 255.0

    # HWC -> CHW -> 1xCxHxW
    img_chw = np.transpose(img_norm, (2, 0, 1))
    input_tensor = np.expand_dims(img_chw, axis=0)
    return input_tensor, img_rgb
def postprocess(output, original_shape, num_classes=1, threshold=0.5):
    """Convert a raw segmentation output into a uint8 mask.

    :param output: model output as an array-like of shape ``(1, C, H, W)``,
        ``(C, H, W)`` or a single-channel ``(H, W)`` map. For a 4-D input
        only the first batch element is used.
    :param original_shape: shape of the image the mask should match; only
        the first two entries ``(H, W)`` are read.
    :param num_classes: ``1`` selects binary segmentation (threshold channel
        0); any other value selects multi-class segmentation (argmax over
        the channel axis).
    :param threshold: cut-off applied to channel 0 in the binary case. The
        values are compared as-is, so the model is assumed to already emit
        probabilities (e.g. after a sigmoid) -- confirm for your model.
    :return: uint8 array of shape ``(H, W)``. Binary case: values are 0 or
        255. Multi-class case: values are class indices.
    :raises ValueError: if ``output`` is not 2-D, 3-D or 4-D, or if the
        multi-class output has more than 256 channels (indices would not
        fit into uint8).
    """
    output = np.asarray(output)

    if output.ndim == 4:
        output = output[0]  # drop the batch axis
    elif output.ndim == 2:
        output = output[np.newaxis, ...]  # treat (H, W) as one channel
    if output.ndim != 3:
        raise ValueError(
            f"expected output with 2, 3 or 4 dimensions, got shape {output.shape}"
        )

    h, w = original_shape[:2]

    if num_classes == 1:
        # Binary segmentation: threshold the single probability map.
        mask = (output[0] > threshold).astype(np.uint8) * 255
    else:
        # Multi-class segmentation: pick the best-scoring class per pixel.
        if output.shape[0] > 256:
            raise ValueError(
                f"{output.shape[0]} channels cannot be stored as uint8 class indices"
            )
        mask = np.argmax(output, axis=0).astype(np.uint8)

    # Nearest-neighbour keeps the mask values discrete; skip the call
    # entirely when the mask already has the requested size.
    if mask.shape[:2] != (h, w):
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return mask
def visualize_segmentation(image, mask, alpha=0.5, colormap=cv2.COLORMAP_JET):
    """Blend a color-mapped mask over an image.

    :param image: image to draw on; passed to ``cv2.addWeighted`` together
        with the colored mask, so both must have the same size and channel
        count.
    :param mask: segmentation mask. A 2-D mask whose maximum is <= 1 is
        scaled by 255 first; any other non-uint8 mask is clipped to
        [0, 255] and cast to uint8, because ``cv2.applyColorMap`` only
        accepts 8-bit input.
    :param alpha: weight of the colored mask; the image gets ``1 - alpha``.
    :param colormap: OpenCV colormap id handed to ``cv2.applyColorMap``.
    :return: the blended overlay.

    NOTE: ``cv2.applyColorMap`` produces a BGR image. Pass a BGR ``image``
    if the colormap colors must keep their usual meaning; with an RGB
    ``image`` the colormap's red and blue channels end up swapped.
    """
    if mask.ndim == 2 and mask.max() <= 1:
        # Stretch a {0, 1} / [0, 1] mask to the full 8-bit range.
        mask = (mask * 255).astype(np.uint8)
    elif mask.dtype != np.uint8:
        # applyColorMap rejects non-8-bit input (e.g. int64 class ids).
        mask = np.clip(mask, 0, 255).astype(np.uint8)

    colored_mask = cv2.applyColorMap(mask, colormap)
    return cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
# ==================== Main program ====================
if __name__ == "__main__":
    ONNX_MODEL = "best.onnx"
    TEST_IMAGE = "1.jpg"
    INPUT_SIZE = (640, 640)
    NUM_CLASSES = 1  # 1: binary segmentation; >1: multi-class segmentation
    THRESHOLD = 0.8

    # 1. Load the ONNX model (CUDA first, CPU as fallback).
    print("📦 加载 ONNX 模型...")
    ort_session = ort.InferenceSession(
        ONNX_MODEL,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    # Name of the model's first input, used as the feed key.
    input_name = ort_session.get_inputs()[0].name

    # 2. Preprocess. original_image is the resized RGB image.
    input_tensor, original_image = preprocess(TEST_IMAGE, INPUT_SIZE)
    print(f"📥 输入形状: {input_tensor.shape}")

    # 3. Inference.
    print("🚀 开始推理...")
    outputs = ort_session.run(None, {input_name: input_tensor})
    logits = outputs[0]  # assumes the model has a single relevant output
    print(f"📤 输出形状: {logits.shape}")

    # 4. Postprocess into a mask matching the displayed image size.
    mask = postprocess(
        logits,
        original_image.shape,
        num_classes=NUM_CLASSES,
        threshold=THRESHOLD,
    )

    # 5. Visualize. cv2.applyColorMap yields BGR colors, so blend in BGR
    #    space and convert the result back to RGB for matplotlib; blending
    #    directly with the RGB image would swap the colormap's red and blue.
    image_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
    overlay_bgr = visualize_segmentation(image_bgr, mask, alpha=0.6)
    overlay = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)

    # 6. Show original image, predicted mask and overlay side by side.
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 3, 1)
    plt.title("原始图像")
    plt.imshow(original_image)
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.title("预测 Mask")
    plt.imshow(mask, cmap="gray")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.title("叠加效果")
    plt.imshow(overlay)
    plt.axis("off")

    plt.tight_layout()
    plt.show()

    # Optional: save the result (overlay is RGB; cv2.imwrite expects BGR).
    # cv2.imwrite("overlay.jpg", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    print("✅ 验证完成!")