Pytorch distiller库问题整理 – 简书项目主页文档地址 1. Invoked with: OperatorExportTypes.RAW, False 在使用ditiller库中的apputils.Summa…https://www.jianshu.com/p/bf3a46791f47
Introduction
名词解释:
feature map:(FM for short),每个卷积层的input plane。
input channel:就像RGB channel一样。
activation:每个卷积层的输出。
convolution output shape(N,Cout,Hout,Wout),其中N是batch size,Cout表示输出channel,Hout是输出channel的height,Wout是width。我们不会对batch size大小有很多关注,因为它对我们的讨论不重要,因此不失一般性我们将N设为1.卷积weight为:(F,C,K,K),F为filter的数量,C是channel的数量,K是kernel大小。每个kernel与一个2Dinput channel(也就是feature map)进行卷积,所以如果输入有Cin个channel,一个filter中就有Cin个kernel。每个filter与整个input进行卷积产生一个output channel(feature map),如果有Cout个output channel,就有Cout个filter。
Filter Pruning
pruners: example_pruner: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.50 group_type: Filters weights: [module.conv1.weight]
考虑在两个卷积层之间添加一个batch-norm层:

Batch-Normalization层由几个tensor进行参数化,这些tensor包含每个input channel的信息(即scale和shift)。因为我们的卷积产生的输出FM较少,而这些FM是BN层的输入,所以我们还需要重新配置BN层。而且,我们还需要缩小BN层的scale和shift tensor,它们是BN输入变换中的系数。此外,我们从tensor中删除的scale和shift系数必须与我们从卷积权重tensor中删除的filter(或输出feature map channel)相对应。这个细节会让人很痛苦,但我们将在后面的示例中解决。上面示例中BN层的存在对我们没有影响,实际上,YAMLschedule不变。Distiller会检测到BN层的存在并自动调整其参数。
考虑另一个与非串行数据相关的示例。在此,conv1的输出是conv2和conv3的输入。这是并行数据依赖的示例。
由于我们仍然仅显式修剪conv1的weight filter,因此YAML schedule与前两个示例相同。conv2和conv3的weight channel在distiller的一个名为“Thinning”的process中被隐式修剪。
另一个涉及到三个卷积的例子,这次我们想剪掉两个卷积层的filter,其输出逐元素相加并喂入第三个卷积。在这个例子中conv3对conv1和conv2都有依赖性,对这种依赖性有两种含义。第一个且是最明显的含义是,我们需要从conv1和conv2中修剪相同数量的filter。因为我们对conv1和conv2的输出进行了逐元素相加,它们必须有同一shape,且如果conv1和conv2修剪掉同样数量的filter他们也只能有同一形状。这种三角数据依赖性的第二层含义是,conv1和conv2必须剪掉同样的filter。如果如下图中忽略掉第二个约束,就会出现两难境地:不能修剪conv3的权重channel。

我们必须应用第二个约束,这也意味着我们现在需要积极主动:我们需要通过conv1或conv2的filter剪枝选择来决定是否修剪conv1和conv2.下图解释了在决定根据conv1的剪枝choice后进行剪枝。
YAML schedule语法需要能够表达上面讨论的两个约束。首先,我们需要告诉filter pruner我们存在一个leader类型的依赖项。这意味着该weights中列出的所有tensor都将在每次迭代中以相同程度进行修剪,并且为了修剪filter,我们将使用列出的第一个tensor的修剪决策。在下面的示例中,通过module.conv1.weight 的剪枝选择来共同修剪module.conv1.weight和module.conv2.weight.
pruners: example_pruner: class: L1RankedStructureParameterPruner_AGP initial_sparsity : 0.10 final_sparsity: 0.50 group_type: Filters group_dependency: Leader weights: [module.conv1.weight, module.conv2.weight]
当我们对ResNet应用filter-pruning时,由于skip connection,我们会看到一些很长的依赖链。如果不注意,会很容易的未指定(或错误指定)依赖链并且distiller会退出并报错。该异常不能解释规范错误,因此需要改进。
channel pruning
发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/225630.html原文链接:https://javaforall.net
