import torch
import torch.nn as nn

from ultralytics import YOLO

# ------------------- 核心剪枝函数 -------------------
def prune_conv_bn(conv_bn, keep_idx):
    """Prune the Conv + BN pair inside a ConvBNAct module.

    Keeps only the output channels listed in ``keep_idx`` and replaces
    ``conv_bn.conv`` / ``conv_bn.bn`` in place with slimmer layers that
    carry over the selected filters and BN parameters/statistics.

    Args:
        conv_bn: Module exposing ``.conv`` (nn.Conv2d) and ``.bn``
            (nn.BatchNorm2d or None) attributes.
        keep_idx: 1-D tensor/sequence of output-channel indices to keep.

    Returns:
        The same ``conv_bn`` module, mutated in place (unchanged for
        grouped/depthwise convolutions).
    """
    conv = conv_bn.conv
    bn = conv_bn.bn

    # Skip depthwise/grouped convs: slicing output channels would also
    # require re-deriving the group structure, so leave them untouched.
    if conv.groups != 1:
        return conv_bn

    # Sort the indices so surviving channels keep their original relative
    # order (torch.topk yields indices ordered by importance, not position).
    keep_idx = torch.as_tensor(keep_idx, device=conv.weight.device).sort().values

    # Build the pruned conv and copy the selected filters over.
    new_conv = nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=len(keep_idx),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=(conv.bias is not None)
    ).to(conv.weight.device)
    new_conv.weight.data = conv.weight.data[keep_idx].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[keep_idx].clone()

    # Build the pruned BN. Carry over eps/momentum and the tracked batch
    # count as well — dropping them changes inference normalization.
    if bn is not None:
        new_bn = nn.BatchNorm2d(
            len(keep_idx), eps=bn.eps, momentum=bn.momentum
        ).to(bn.weight.device)
        new_bn.weight.data = bn.weight.data[keep_idx].clone()
        new_bn.bias.data = bn.bias.data[keep_idx].clone()
        new_bn.running_mean = bn.running_mean[keep_idx].clone()
        new_bn.running_var = bn.running_var[keep_idx].clone()
        if bn.num_batches_tracked is not None:
            new_bn.num_batches_tracked = bn.num_batches_tracked.clone()
    else:
        new_bn = None

    # Swap the slimmer layers into the wrapper module.
    conv_bn.conv = new_conv
    conv_bn.bn = new_bn
    return conv_bn
def get_prune_idx(conv_bn, prune_ratio=0.3):
    """Select the output channels to KEEP for a ConvBNAct module.

    Channel importance is |BN gamma| when a BN layer is present,
    otherwise the per-filter L2 norm of the conv weights.

    Args:
        conv_bn: Module exposing ``.conv`` and ``.bn`` attributes.
        prune_ratio: Fraction of channels to remove, in [0, 1).

    Returns:
        1-D LongTensor of channel indices to keep, sorted ascending so
        the surviving channels preserve their original relative order.
    """
    conv = conv_bn.conv
    bn = conv_bn.bn
    if bn is not None:
        importance = bn.weight.data.abs()
    else:
        importance = conv.weight.data.view(conv.out_channels, -1).norm(p=2, dim=1)
    # Always keep at least one channel, even at extreme prune ratios.
    keep_num = max(int(conv.out_channels * (1 - prune_ratio)), 1)
    _, idxs = torch.topk(importance, keep_num)
    # topk orders indices by importance; sort to restore channel order.
    return idxs.sort().values
def prune_yolov11_model(model, prune_ratio=0.3):
    """Walk every submodule of a YOLO model and prune each ConvBNAct in place."""
    # Match by class name rather than isinstance so this file needs no
    # import of the ConvBNAct class itself.
    targets = [
        module
        for _, module in model.named_modules()
        if type(module).__name__ == "ConvBNAct"
    ]
    for module in targets:
        prune_conv_bn(module, get_prune_idx(module, prune_ratio))
    return model
# ------------------- 主流程 -------------------
def main(model_path="best.pt", save_path="yolov11_pruned_ts.pt",
         prune_ratio=0.3, device="cuda"):
    """Load a YOLO checkpoint, prune its ConvBNAct channels, and export
    the result as a TorchScript model.

    Args:
        model_path: Path to the .pt checkpoint to load.
        save_path: Destination file for the traced TorchScript model.
        prune_ratio: Fraction of channels to prune per ConvBNAct.
        device: Target device string; falls back to CPU when CUDA is
            requested but unavailable instead of crashing.
    """
    # Graceful fallback so the script also runs on CUDA-less machines.
    if device.startswith("cuda") and not torch.cuda.is_available():
        print("⚠️ CUDA 不可用,回退到 CPU")
        device = "cpu"

    # Unwrap the underlying nn.Module from the ultralytics wrapper.
    model = YOLO(model_path).model
    model.eval().to(device)

    # Prune
    print(f"✅ 开始剪枝,比例: {prune_ratio}")
    model = prune_yolov11_model(model, prune_ratio)
    print("✅ 剪枝完成")

    # Dummy input matching YOLO's standard 640x640 RGB input shape.
    example_inputs = torch.randn(1, 3, 640, 640).to(device)

    # Trace without autograd bookkeeping — this is an inference-only export.
    print("🔹 开始 TorchScript 跟踪...")
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_inputs)
    traced_model = torch.jit.optimize_for_inference(traced_model)

    # Persist the TorchScript artifact.
    traced_model.save(save_path)
    print(f"✅ TorchScript 剪枝模型已保存: {save_path}")
# Script entry point: prune with the default checkpoint and ratio.
if __name__ == "__main__":
    main(model_path="best.pt",
         save_path="yolov11_pruned_ts.pt",
         prune_ratio=0.3)