import torch
import torch.nn as nn


# ------------------- core pruning functions -------------------
def prune_conv_bn(conv_bn, keep_idx):
    """Prune the Conv + BN pair of a ConvBNAct module along output channels.

    Args:
        conv_bn: module exposing ``.conv`` (nn.Conv2d) and ``.bn``
            (nn.BatchNorm2d or None) attributes; modified in place.
        keep_idx: 1-D LongTensor of output-channel indices to keep.

    Returns:
        The same ``conv_bn`` with ``.conv``/``.bn`` replaced by pruned copies.

    NOTE(review): this shrinks only the *output* channels of this conv; the
    *input* channels of whatever layer consumes it are untouched, so pruning
    a real network this way generally breaks the forward pass unless the
    downstream layers are repaired too — confirm before relying on it.
    """
    conv = conv_bn.conv
    bn = conv_bn.bn

    # Depthwise/grouped convs tie groups to channel count; skip them entirely.
    if conv.groups != 1:
        return conv_bn

    # Rebuild the conv with fewer output channels, preserving every other
    # hyper-parameter (including padding_mode, which was previously dropped).
    new_conv = nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=len(keep_idx),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=(conv.bias is not None),
        padding_mode=conv.padding_mode,
    ).to(conv.weight.device)
    new_conv.weight.data = conv.weight.data[keep_idx].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[keep_idx].clone()

    # Rebuild BN with the matching channel count, preserving eps/momentum/
    # affine/track_running_stats instead of silently resetting them.
    if bn is not None:
        new_bn = nn.BatchNorm2d(
            len(keep_idx),
            eps=bn.eps,
            momentum=bn.momentum,
            affine=bn.affine,
            track_running_stats=bn.track_running_stats,
        ).to(conv.weight.device)
        if bn.affine:
            new_bn.weight.data = bn.weight.data[keep_idx].clone()
            new_bn.bias.data = bn.bias.data[keep_idx].clone()
        if bn.track_running_stats:
            new_bn.running_mean.data = bn.running_mean.data[keep_idx].clone()
            new_bn.running_var.data = bn.running_var.data[keep_idx].clone()
            new_bn.num_batches_tracked.data = bn.num_batches_tracked.data.clone()
    else:
        new_bn = None

    # Swap the pruned layers into the wrapper module.
    conv_bn.conv = new_conv
    conv_bn.bn = new_bn
    return conv_bn


def get_prune_idx(conv_bn, prune_ratio=0.3):
    """Select output channels to keep, ranked by BN gamma or filter L2 norm.

    Args:
        conv_bn: module with ``.conv``/``.bn`` attributes.
        prune_ratio: fraction of output channels to remove (at least one
            channel is always kept).

    Returns:
        Sorted 1-D LongTensor of kept indices — sorting preserves the
        original channel order (``topk`` alone returns indices ordered by
        score, which would reorder channels).
    """
    conv = conv_bn.conv
    bn = conv_bn.bn
    # Importance score: |gamma| when an affine BN follows, else the per-filter
    # L2 norm of the conv weights (also covers affine=False BNs, whose
    # ``weight`` is None).
    if bn is not None and bn.weight is not None:
        score = bn.weight.data.abs()
    else:
        score = conv.weight.data.view(conv.out_channels, -1).norm(p=2, dim=1)
    keep_num = max(int(conv.out_channels * (1 - prune_ratio)), 1)
    _, idxs = torch.topk(score, keep_num)
    return torch.sort(idxs).values


def prune_yolov11_model(model, prune_ratio=0.3):
    """Walk the YOLO model and prune every module named ``ConvBNAct``.

    Matching is done on the class *name* (string) so the project's own
    ConvBNAct type need not be imported here.
    """
    for name, m in model.named_modules():
        if m.__class__.__name__ == "ConvBNAct":
            keep_idx = get_prune_idx(m, prune_ratio)
            prune_conv_bn(m, keep_idx)
    return model


# ------------------- main flow -------------------
def main(model_path="best.pt", save_path="yolov11_pruned_ts.pt",
         prune_ratio=0.3, device="cuda"):
    """Load a YOLO checkpoint, prune it, and export a TorchScript model.

    Args:
        model_path: path to the ultralytics ``.pt`` checkpoint.
        save_path: where the traced TorchScript module is written.
        prune_ratio: per-layer fraction of channels to remove.
        device: target device; falls back to CPU when CUDA is unavailable.
    """
    # Heavy optional dependency: import lazily so the pruning helpers above
    # remain importable (and testable) without ultralytics installed.
    from ultralytics import YOLO

    # Fall back to CPU instead of crashing on CUDA-less machines.
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    # Load the underlying nn.Module from the ultralytics wrapper.
    model = YOLO(model_path).model
    model.eval().to(device)

    # Prune
    print(f"✅ 开始剪枝,比例: {prune_ratio}")
    model = prune_yolov11_model(model, prune_ratio)
    print("✅ 剪枝完成")

    # Dummy input for tracing (standard YOLO 640x640 RGB).
    example_inputs = torch.randn(1, 3, 640, 640).to(device)

    # TorchScript tracing — no autograd state needed for inference export.
    print("🔹 开始 TorchScript 跟踪...")
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_inputs)
        traced_model = torch.jit.optimize_for_inference(traced_model)

    # Save the TorchScript model.
    traced_model.save(save_path)
    print(f"✅ TorchScript 剪枝模型已保存: {save_path}")


if __name__ == "__main__":
    main(
        model_path="best.pt",
        save_path="yolov11_pruned_ts.pt",
        prune_ratio=0.3,
    )