pytorch固定BN层参数[通俗易懂]

pytorch固定BN层参数[通俗易懂]背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。原因:未固定主分支BN层中的running_mean和running_var。解决方法:将需要固定的BN层状态设置为eval。问题示例:环境:torch:1.7.0#-*-coding:utf-8-*-importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.M

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])

参考:

Cannot freeze batch normalization parameters

BatchNorm2d

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/183783.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • RT-thread finsh移植到linux平台

    RT-thread finsh移植到linux平台目录FinSH介绍传统命令行模式C语言解释器模式FinSH移植移植要点效果验证代码下载参考在一次项目中,需要进行嵌入式操作系统选型,需求就是选择一款OS,既能满足当下项目的需要,又要考虑公司未来对物联网应用的扩展能力,对比了目前市面上流行的开源操作系统,诸如FreeRTOS,RTX,UCOS,RT-Thread,contiki等,最终确定了一款Io…

    2022年5月21日
    31
  • phpstorm 激活码【注册码】

    phpstorm 激活码【注册码】,https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月20日
    34
  • a4988 脉宽要求_A4988驱动模块使用详解(附:电流调节方法)

    a4988 脉宽要求_A4988驱动模块使用详解(附:电流调节方法)DIY3D打印机的时候,各种驱动、主板、固件等的最让人头疼,稍不注意就有可能烧机….这方面的知识不补不行啊。今天给大家介绍下A4988驱动,很小很便宜的一个部件,但学问不少哦,一起来看看吧。A4988简介A4988是一款完全的微步电动机驱动器,带有内置转换器,易于操作。该产品可在全、半、1/4、1/8及1/16步进模式时操作双极步进电动机,输出驱动性能可达35V及±1A。…

    2022年6月16日
    76
  • oracle创建索引和删除索引_oracle删除索引语句

    oracle创建索引和删除索引_oracle删除索引语句索引的创建createindexindex_nameontable(column_name1,column_name2);创建唯一索引createindexuniqueindex_nameontable(column_name1,column_name2);索引的删除。dropindexindex_name;以下两条语句是…

    2022年9月5日
    4
  • vm虚拟机安装win11_虚拟机15.5安装教程win7

    vm虚拟机安装win11_虚拟机15.5安装教程win7首先下载好虚拟机以及系统,并且把iso镜像解压好!打开虚拟机! 首先,选择创建虚拟机,然后选择典型.点击下一步! 选择你刚才下载的iso镜像文件.点击下一步! 选择XP版本,点击下一步,下一步是系统的存放位置,和系统名字,看自己怎么样方便吧!在点击下一步,是磁盘空间,这个随便选都可以,如果安装的系统系统用多少内存,就会消耗本机硬盘多少内存,没关系…

    2022年8月16日
    3
  • linux lseek

    linux lseek

    2022年6月25日
    20

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号