pytorch BatchNorm参数详解,计算过程

pytorch BatchNorm参数详解,计算过程BatchNorm1d的参数:torch.nn.BatchNorm1d(num_features,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)

大家好,又见面了,我是你们的朋友全栈君。

目录

 

说明

BatchNorm1d参数

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d训练时前向传播

BatchNorm1d评估时前向传播

总结


说明

网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval(),使网络里的BatchNorm模块的training=False。

BatchNorm1d参数

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

输入维度是(N, C, L)时,num_features应该取C;这里N是batch size,C是数据的channel,L是数据长度。

输入维度是(N, L)时,num_features应该取L;这里N是batch size,L是数据长度,这时可以认为每条数据只有一个channel,省略了C

eps

对输入数据进行归一化时加在分母上,防止除零,详情见下文。

momentum

更新全局均值running_mean和方差running_var时使用该值进行平滑,详情见下文。

affine

设为True时,BatchNorm层才会学习参数\gamma\beta,否则不包含这两个变量,变量名是weightbias,详情见下文。

track_running_stats

设为True时,BatchNorm层会统计全局均值running_mean和方差running_var,详情见下文。

BatchNorm1d训练时前向传播

  1. 首先对输入batch求E[x]Var[x],并用这两个结果把batch归一化,使其均值为0,方差为1。归一化公式用到了eps(\epsilon),即y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon }}。如下输入内容,shape是(3, 4),即batch_size=3,此时num_features需要传入4。
    tensor = torch.FloatTensor([[1, 2, 4, 1],
                                [6, 3, 2, 4],
                                [2, 4, 6, 1]])

    此时E[x]=[3, 3, 4, 2]Var[y]_{unbiased}=[7, 1, 4, 3](无偏样本方差)和Var[y]_{biased}=[4.6667, 0.6667, 2.6667, 2.0000](有偏样本方差),有偏和无偏的区别在于无偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用于更新running_var时,使用无偏样本方差即,但是在对batch进行归一化时,使用有偏样本方差,因此如果batch_size=1,会报错。归一化后的内容如下。

    [[-0.9258, -1.2247,  0.0000, -0.7071],
     [ 1.3887,  0.0000, -1.2247,  1.4142],
     [-0.4629,  1.2247,  1.2247, -0.7071]]
  2. 如果track_running_stats==True,则使用momentum更新模块内部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是x_{new}=(1-momentum)\times x_{cur}+momentum\times x_{batch},其中x_{new}代表更新后的running_meanrunning_varx_{cur}表示更新前的running_meanrunning_varx_{batch}表示当前batch的均值和无偏样本方差。
  3. 如果track_running_stats==False,则BatchNorm中不含有running_meanrunning_var两个变量。
  4. 如果affine==True,则对归一化后的batch进行仿射变换,即乘以模块内部的weight(初值是[1., 1., 1., 1.])然后加上模块内部的bias(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。
  5. 如果affine==False,则BatchNorm中不含有weightbias两个变量,什么都都不做。

BatchNorm1d评估时前向传播

  1. 如果track_running_stats==True,则对batch进行归一化,公式为y=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon }},注意这里的均值和方差是running_meanrunning_var,在网络训练时统计出来的全局均值和无偏样本方差。
  2. 如果track_running_stats==False,则对batch进行归一化,公式为y=\frac{x-{E}[x]}{\sqrt{​{Var}[x]+\epsilon }},注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_meanrunning_var注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。
  3. 如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。
  4. 如果affine==False,则BatchNorm中不含有weightbias两个变量,什么都不做。

总结

在使用batchNorm时,通常只需要指定num_features就可以了。网络训练前调用train(),训练时BatchNorm模块会统计全局running_meanrunning_var,学习weightbias,即文献中的\gamma\beta。网络评估前调用eval(),评估时,对传入的batch,使用统计的全局running_meanrunning_var对batch进行归一化,然后使用学习到的weightbias进行仿射变换。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141980.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月24日 下午6:20
下一篇 2022年5月24日 下午6:20


相关推荐

  • 智谱港股上市:市值超500亿港元

    智谱港股上市:市值超500亿港元

    2026年3月12日
    2
  • PyCharm和git安装教程

    PyCharm和git安装教程先到官网下载 githttps git scm com download win 进入 setting 如黄色部分如果你用的是 github 那么直接 setting 登陆就行了如果你是 gitee 的话首先进入 setting 然后 Plugins 点击 browse 查找 gitee 如图所示 最后点击重启 ok 不要自己关闭 否则安装失败 安装好了以后 这里走了一些弯路省去不写 直接写正确答案 根据经验

    2026年3月27日
    1
  • javaweb权限管理简单实现_javaweb用户权限管理

    javaweb权限管理简单实现_javaweb用户权限管理推荐最新技术springboot版权限管理(java后台通用权限管理系统(springboot)),采用最新技术架构,功能强大!注:由于该项目比较老,所以没有采用maven管理,建议下载springboot权限管理系统,对学习和使用会更有帮助。springboot权限管理系统介绍地址:https://blog.csdn.net/zwx19921215/article/details/978……………

    2026年4月14日
    4
  • 结构体和类使用的区别

    结构体和类使用的区别前段时间写推力叠加时遇到的一个问题 我当时最开始是用两个列表分别存储由推力和方向得到的速度 速度的持续时间 下标一一对应 后来觉得可以用结构体来存储速度和持续时间 这样就只需要一个列表管理就可以了 能少用一次遍历 同时更好的面向对象吧 然后用结构体改的时候却发现结构体里的字段不能直接用来加等减等运算 因为结构体里的字段是放在栈里的 为值类型 后来就把结构体改成类解决了这个问题 因为类的存储是在堆里

    2026年3月19日
    0
  • php三个数从大到小排列_单分支if语句和双分支

    php三个数从大到小排列_单分支if语句和双分支<?php$a = rand(100,999);$b = rand(100,999);$c = rand(100,999);echo “a=”.”$a”.”<br>”;echo “b=”.”$b”.”<br>”;echo “c=”.”$c”.”<br>”;if(($a > $b ) && ($a > …

    2022年8月18日
    12
  • cisco交换机常用命令[通俗易懂]

    一台全新交换机,不同模式命令大全交换机基本状态显示及各个状态切换:hostname&gt;用户模式hostname#特权模式hostname(config)#全局配置模式hostname(config-if)#接口或者vlan多个接口配置模式hostname(vlan)#…

    2022年4月7日
    52

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号