rknn替换,板子是3568的
This commit is contained in:
BIN
ailai_pc/jianzhi/best.pt
Normal file
BIN
ailai_pc/jianzhi/best.pt
Normal file
Binary file not shown.
BIN
ailai_pc/jianzhi/jz.pt
Normal file
BIN
ailai_pc/jianzhi/jz.pt
Normal file
Binary file not shown.
95
ailai_pc/jianzhi/jz.py
Normal file
95
ailai_pc/jianzhi/jz.py
Normal file
@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ultralytics import YOLO
|
||||
|
||||
# ------------------- 核心剪枝函数 -------------------
|
||||
def prune_conv_bn(conv_bn, keep_idx):
|
||||
"""剪枝 ConvBNAct 模块的 Conv + BN"""
|
||||
conv = conv_bn.conv
|
||||
bn = conv_bn.bn
|
||||
|
||||
# 跳过 depthwise
|
||||
if conv.groups != 1:
|
||||
return conv_bn
|
||||
|
||||
# 剪枝 conv
|
||||
new_conv = nn.Conv2d(
|
||||
in_channels=conv.in_channels,
|
||||
out_channels=len(keep_idx),
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
dilation=conv.dilation,
|
||||
groups=conv.groups,
|
||||
bias=(conv.bias is not None)
|
||||
).to(conv.weight.device)
|
||||
new_conv.weight.data = conv.weight.data[keep_idx].clone()
|
||||
if conv.bias is not None:
|
||||
new_conv.bias.data = conv.bias.data[keep_idx].clone()
|
||||
|
||||
# 剪枝 BN
|
||||
if bn is not None:
|
||||
new_bn = nn.BatchNorm2d(len(keep_idx)).to(bn.weight.device)
|
||||
new_bn.weight.data = bn.weight.data[keep_idx].clone()
|
||||
new_bn.bias.data = bn.bias.data[keep_idx].clone()
|
||||
new_bn.running_mean = bn.running_mean[keep_idx].clone()
|
||||
new_bn.running_var = bn.running_var[keep_idx].clone()
|
||||
else:
|
||||
new_bn = None
|
||||
|
||||
# 替换模块
|
||||
conv_bn.conv = new_conv
|
||||
conv_bn.bn = new_bn
|
||||
return conv_bn
|
||||
|
||||
def get_prune_idx(conv_bn, prune_ratio=0.3):
|
||||
"""根据 BN gamma 或 L2 norm 计算要保留的通道索引"""
|
||||
conv = conv_bn.conv
|
||||
bn = conv_bn.bn
|
||||
if bn is not None:
|
||||
gamma = bn.weight.data.abs()
|
||||
else:
|
||||
gamma = conv.weight.data.view(conv.out_channels, -1).norm(p=2, dim=1)
|
||||
keep_num = max(int(conv.out_channels * (1 - prune_ratio)), 1)
|
||||
_, idxs = torch.topk(gamma, keep_num)
|
||||
return idxs
|
||||
|
||||
def prune_yolov11_model(model, prune_ratio=0.3):
|
||||
"""遍历 YOLO 模型,剪枝所有 ConvBNAct"""
|
||||
for name, m in model.named_modules():
|
||||
if m.__class__.__name__ == "ConvBNAct":
|
||||
keep_idx = get_prune_idx(m, prune_ratio)
|
||||
prune_conv_bn(m, keep_idx)
|
||||
return model
|
||||
|
||||
# ------------------- 主流程 -------------------
|
||||
def main(model_path="best.pt", save_path="yolov11_pruned_ts.pt",
|
||||
prune_ratio=0.3, device="cuda"):
|
||||
|
||||
# 加载 YOLO 模型
|
||||
model = YOLO(model_path).model
|
||||
model.eval().to(device)
|
||||
|
||||
# 剪枝
|
||||
print(f"✅ 开始剪枝,比例: {prune_ratio}")
|
||||
model = prune_yolov11_model(model, prune_ratio)
|
||||
print("✅ 剪枝完成")
|
||||
|
||||
# 构造 dummy 输入
|
||||
example_inputs = torch.randn(1, 3, 640, 640).to(device)
|
||||
|
||||
# TorchScript 跟踪
|
||||
print("🔹 开始 TorchScript 跟踪...")
|
||||
traced_model = torch.jit.trace(model, example_inputs)
|
||||
traced_model = torch.jit.optimize_for_inference(traced_model)
|
||||
|
||||
# 保存 TorchScript 模型
|
||||
traced_model.save(save_path)
|
||||
print(f"✅ TorchScript 剪枝模型已保存: {save_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(
|
||||
model_path="best.pt",
|
||||
save_path="yolov11_pruned_ts.pt",
|
||||
prune_ratio=0.3
|
||||
)
|
||||
BIN
ailai_pc/jianzhi/yolov11_pruned.pt
Normal file
BIN
ailai_pc/jianzhi/yolov11_pruned.pt
Normal file
Binary file not shown.
Reference in New Issue
Block a user