1.distiller剪枝模块的使用
(1)distiller自带剪枝实例测试
distiller自带一些测试实例如ResNet56+cifar-10,下面是对ResNet56+cifar-10的测试:
- 测试前准备
- yaml文件(注意:这里的yaml文件是coder配置好的,具体到自己的模型需要先对自己的model进行一次Sparsity Analysis,然后自己配置该文件) 在剪枝时所用到的yaml文件作用主要是配置了一些剪枝所需要的必要信息,比如下面ResNet56所需要用的yaml配置文件(路径:distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml):
version: 1 # 版本 pruners: filter_pruner_60: # 后面的60表示剪掉60%的Filters,如[16, 16, 3, 3]剪掉之后就是[7, 16, 3, 3] class: 'L1RankedStructureParameterPruner' # 表示所使用的算法,这里使用L1Rank group_type: Filters # 表示剪切类型,一般两种Filters/Channel desired_sparsity: 0.6 # 剪掉60%的Filters weights: [ # 下面是一些具体的需要剪切的权值 module.layer1.0.conv1.weight, module.layer1.1.conv1.weight, module.layer1.2.conv1.weight, module.layer1.3.conv1.weight, module.layer1.4.conv1.weight, module.layer1.5.conv1.weight, module.layer1.6.conv1.weight, module.layer1.7.conv1.weight, module.layer1.8.conv1.weight] filter_pruner_50: # 同上 class: 'L1RankedStructureParameterPruner' group_type: Filters desired_sparsity: 0.5 weights: [ module.layer2.1.conv1.weight, module.layer2.2.conv1.weight, module.layer2.3.conv1.weight, module.layer2.4.conv1.weight, module.layer2.6.conv1.weight, module.layer2.7.conv1.weight] filter_pruner_10: # 同上 class: 'L1RankedStructureParameterPruner' group_type: Filters desired_sparsity: 0.1 weights: [module.layer3.1.conv1.weight] filter_pruner_30: # 同上 class: 'L1RankedStructureParameterPruner' group_type: Filters desired_sparsity: 0.3 weights: [ module.layer3.2.conv1.weight, module.layer3.3.conv1.weight, module.layer3.5.conv1.weight, module.layer3.6.conv1.weight, module.layer3.7.conv1.weight, module.layer3.8.conv1.weight] extensions: net_thinner: class: 'FilterRemover' thinning_func_str: remove_filters arch: 'resnet56_cifar' # 使用的网络 dataset: 'cifar10' # 数据集 lr_schedulers: exp_finetuning_lr: class: ExponentialLR gamma: 0.95 policies: - pruner: instance_name: filter_pruner_60 epochs: [0] - pruner: instance_name: filter_pruner_50 epochs: [0] - pruner: instance_name: filter_pruner_30 epochs: [0] - pruner: instance_name: filter_pruner_10 epochs: [0] - extension: instance_name: net_thinner epochs: [0] - lr_scheduler: instance_name: exp_finetuning_lr starting_epoch: 10 ending_epoch: 300 frequency: 1
- 准备ResNet-56需要的模型文件,可下载:
https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
- 剪枝
找到compress_classifier.py文件,如下:

$python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0
参数解释: -a 表示模型名称(这里是工具自带的模型名称,其他的如resnet32_cifar, resnet44_cifar, resnet56_cifar等等 cifar的模型代码文件位于distiller/models/cifar10/resnet_cifar.py)
-p表示每隔多少打印一次
../../../data.cifar10是数据集路径
–epochs 表示剪枝过后继续训练次数
–compress 表示所用的‘策略’(compress_scheduler),一般是yaml文件的路径
–resume-from 表示保存的模型的路径
–reset-optimizer 如果设置此参数,那么start_epoch=0,将optimizer重置为SGD, 学习绿设置为传入的学习率
–vs validation-split
具体的其他参数参看distiller/apputils/image_classifier.py文件和distiller/quantization/range_linear.py文件以及github上参数解释。
运行时会对模型进行剪枝,然后在测试集上测试,打印出top1和top5以及loss,运行结束后量化模型会保存在logs下。
(2)distiller对自己的模型剪枝
- 具体流程:1.Sparsity Analysis 分析各层weight的sensitivity,即对模型各个部分的稀疏性和可以pruning的程度有个了解。 2. Yaml file create 创建一个属于自己模型的配置文件,里面各层的稀疏度是由第一阶段分析出来的sensitivity而来的 3.Thinning 对网络进行真正的剪枝。
(3)代码
- 代码已全部整理在 https://github.com/BlossomingL/compression_tool
- distiller修改代码在 https://github.com/BlossomingL/Distiller
- 有用的话麻烦手动点一下小星星,谢谢!
发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/230467.html原文链接:https://javaforall.net
