pytorch固定BN层参数[通俗易懂]

pytorch固定BN层参数[通俗易懂]背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。原因:未固定主分支BN层中的running_mean和running_var。解决方法:将需要固定的BN层状态设置为eval。问题示例:环境:torch:1.7.0#-*-coding:utf-8-*-importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.M

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])

参考:

Cannot freeze batch normalization parameters

BatchNorm2d

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/183783.html原文链接:https://javaforall.net

(0)
上一篇 2022年8月31日 下午3:36
下一篇 2022年8月31日 下午3:46


相关推荐

  • 趁热打铁!HTTPGet 与HTTPPost的区别

    趁热打铁!HTTPGet 与HTTPPost的区别今天在老师工作室做项目的时候 突然看到一个页面用了 2 种不同的传值类型 突然有了兴趣 想弄明白本质的区别 虽然以前用的知道 2 种的用法 但是还是云里雾里的 下面是那位大神的文章 原文链接 nbsp 作者 WebTechGarde 和 POST 是 HTTP 请求的两种基本方法 要说它们的区别 接触过 WEB 开发的人都能说出一二 最直观的区别就是 GET 把参数包含在 URL 中 POST 通过 requestbody 传递参数

    2026年3月18日
    2
  • busybox 安装mysql_安装busybox「建议收藏」

    busybox 安装mysql_安装busybox「建议收藏」安装busybox按以下步骤即可:1.root手机2.查看手机支持的cpu架构:cat/system/build.prop|grepabi我手机查出来的结果如下所示:ro.product.cpu.abi=armeabi-v7aro.product.cpu.abi2=armeabi3.下载适合你手机的Busybox,可以直接下载binary文件,地址如下:https://busybox.ne…

    2022年7月25日
    14
  • 传感器尺寸、像素、DPI分辨率、英寸、毫米的关系

    传感器尺寸、像素、DPI分辨率、英寸、毫米的关系虽然网上有很多这种资料,但是太过于复杂,每个人的说法都不一样,看的让人云里雾里的,我总结了一下,不知道对不对!1.1英寸=25.4mm2.传感器尺寸:传感器的尺寸是指传感器的大小,一般描述大小有两种形式,以IMX386感光元件为例,其传感器尺寸1/2.9英寸,是指传感器对角线为1/2.9英寸;还可以描述成传感器尺寸4.97mm×6.2mm,是指水平(竖直)长(宽)为4.97(6.2)m…

    2022年6月13日
    57
  • R语言多个for循环嵌套使用

    R语言多个for循环嵌套使用R 语言 多个 for 循环联合使用提高数据处理效率

    2026年3月17日
    3
  • 图像特征提取总结_将劣势转化为优势的例子

    图像特征提取总结_将劣势转化为优势的例子转载地址:https://blog.csdn.net/lskyne/article/details/8654856 特征提取是计算机视觉和图像处理中的一个概念。它指的是使用计算机提取图像信息,决定每个图像的点是否属于一个图像特征。特征提取的结果是把图像上的点分为不同的子集,这些子集往往属于孤立的点、连续的曲线或者连续的区域。 特征的定义        至今为止特征没有万能和精确的定义。…

    2025年7月8日
    4
  • Android 加密 AES

    Android 加密 AESAES加密又称对称性加密,在开发中常用于对流数据对加密,尤其是流数据在网络传输过程中,担心被泄露,AES加密被常用于这块的校验中。下面是AES加密的百度百科说明解释:AES加密标准又称为高级加密标准Rijndael加密法,是美国国家标准技术研究所NIST旨在取代DES的21世纪的加密标准。AES的基本要求是,采用对称分组密码体制,密钥长度可以为128、192或256…

    2022年5月16日
    36

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号