Files
2025-11-03 16:10:50 +08:00

96 lines
2.9 KiB
Python

import torch
import torch.nn as nn
from ultralytics import YOLO
# ------------------- 核心剪枝函数 -------------------
def prune_conv_bn(conv_bn, keep_idx):
"""剪枝 ConvBNAct 模块的 Conv + BN"""
conv = conv_bn.conv
bn = conv_bn.bn
# 跳过 depthwise
if conv.groups != 1:
return conv_bn
# 剪枝 conv
new_conv = nn.Conv2d(
in_channels=conv.in_channels,
out_channels=len(keep_idx),
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=(conv.bias is not None)
).to(conv.weight.device)
new_conv.weight.data = conv.weight.data[keep_idx].clone()
if conv.bias is not None:
new_conv.bias.data = conv.bias.data[keep_idx].clone()
# 剪枝 BN
if bn is not None:
new_bn = nn.BatchNorm2d(len(keep_idx)).to(bn.weight.device)
new_bn.weight.data = bn.weight.data[keep_idx].clone()
new_bn.bias.data = bn.bias.data[keep_idx].clone()
new_bn.running_mean = bn.running_mean[keep_idx].clone()
new_bn.running_var = bn.running_var[keep_idx].clone()
else:
new_bn = None
# 替换模块
conv_bn.conv = new_conv
conv_bn.bn = new_bn
return conv_bn
def get_prune_idx(conv_bn, prune_ratio=0.3):
"""根据 BN gamma 或 L2 norm 计算要保留的通道索引"""
conv = conv_bn.conv
bn = conv_bn.bn
if bn is not None:
gamma = bn.weight.data.abs()
else:
gamma = conv.weight.data.view(conv.out_channels, -1).norm(p=2, dim=1)
keep_num = max(int(conv.out_channels * (1 - prune_ratio)), 1)
_, idxs = torch.topk(gamma, keep_num)
return idxs
def prune_yolov11_model(model, prune_ratio=0.3):
"""遍历 YOLO 模型,剪枝所有 ConvBNAct"""
for name, m in model.named_modules():
if m.__class__.__name__ == "ConvBNAct":
keep_idx = get_prune_idx(m, prune_ratio)
prune_conv_bn(m, keep_idx)
return model
# ------------------- 主流程 -------------------
def main(model_path="best.pt", save_path="yolov11_pruned_ts.pt",
prune_ratio=0.3, device="cuda"):
# 加载 YOLO 模型
model = YOLO(model_path).model
model.eval().to(device)
# 剪枝
print(f"✅ 开始剪枝,比例: {prune_ratio}")
model = prune_yolov11_model(model, prune_ratio)
print("✅ 剪枝完成")
# 构造 dummy 输入
example_inputs = torch.randn(1, 3, 640, 640).to(device)
# TorchScript 跟踪
print("🔹 开始 TorchScript 跟踪...")
traced_model = torch.jit.trace(model, example_inputs)
traced_model = torch.jit.optimize_for_inference(traced_model)
# 保存 TorchScript 模型
traced_model.save(save_path)
print(f"✅ TorchScript 剪枝模型已保存: {save_path}")
if __name__ == "__main__":
main(
model_path="best.pt",
save_path="yolov11_pruned_ts.pt",
prune_ratio=0.3
)