LeNet(
(conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
在pytorch里面, prune通过对权重进行掩码来完成. 这个如何理解?
首先, 我们打印一下原始的conv1的权重看看:
module = model.conv1
print(list(module.named_modules()))
print(list(module.named_buffers()))
print(list(module.named_parameters()))
这里列举了后面可能会用到三个方法, 这个可以查看当前的module到底是一个啥情况。
着重观察 named_parameters 因为参数都保存在这里, 打印完了之后可以看到:
[('weight', Parameter containing:
tensor([[[[-0.2312, 0.2133, -0.1313],
[-0.2980, -0.1838, -0.2902],
[-0.3006, 0.1338, -0.0980]]],
[[[-0.1239, 0.1060, 0.3271],
[-0.0301, -0.0245, 0.0493],
[-0.0160, 0.0397, -0.1242]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2593, -0.0520, 0.0303, 0.0382, -0.0468, -0.1053], device='cuda:0',
requires_grad=True))]
它有一个weight和一个bias, 这没错, 合乎常理. 我们甚至可以看看weights的尺寸是多少.
for a in module.named_parameters():
print(a[1].shape)
输出:
torch.Size([6, 1, 3, 3])
torch.Size([6])
这个其实就是卷积的维度了, 6指的是channel, 1值得还是stride, 3指的是kernel size. 然后重点来了, 要开始做prune了, 在pytorch里面操作也很简单, 只需要一行代码:
import torch.nn.utils.prune as prune
prune.random_unstructured(module, name='weight', amount=0.3)
这个可以从众多的剪枝方法中, 选择一个很好的手段来完成同样的目的. 然后我们再打印一下named_parameters:
[('bias', Parameter containing:
tensor([-0.2281, 0.3085, 0.0937, -0.0540, 0.3295, 0.1107], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1934, -0.0172, -0.1957],
[ 0.1655, 0.1669, -0.2448],
[-0.2250, -0.0963, -0.0195]]],
[[[-0.3154, 0.1868, 0.0103],
[-0.2245, 0.1548, 0.2567],
[ 0.0713, 0.1262, 0.1547]]]], device='cuda:0', requires_grad=True))]
唯一的变化就是 weights 变成了 weights_orig, prune之后通过掩码的方式存放在了 named_buffers里面:
print(list(module.named_buffers()))
可以看到:
[('weight_mask', tensor([[[[1., 1., 1.],
[0., 1., 0.],
[0., 1., 1.]]],
[[[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.]]]], device='cuda:0'))]
那么问题来了, 只是把权重进行了掩码, 那么我要知道你剪掉了哪几个channel怎么办? 而且你这个是剪的权重, 结构呢? 我怎么把这个结构找出来??
所以说这只是第一步, 接下来我们来看看结构化修剪. ==结构化修剪讲道理你可以知道你修剪了哪些结构==.
一个比较好的结构化修建的例子是通过沿着Tensor的某个维度进行裁剪, 这样你可以直接看到维度的变化.
现在开始修剪模块, 比如上面的LeNet的conv1层, 首先我们可以从prune层里面拿一个我们喜欢的技术, 比如基于 ln范数 的评判标准来进行结构话的裁剪.
prune.ln_structured(module, name='weight', amount=0.5, n=2, dim=0)
这个操作之后, 我们得到的将是一个新的权重, 和上面的非结构化的不同的地方在于, 这里是整个矩阵的一行为零, 上面我们用的dim=0, 那么就是channel这一个维度, 会有50%为零.