pytorch BatchNorm参数详解,计算过程

pytorch BatchNorm参数详解,计算过程BatchNorm1d的参数:torch.nn.BatchNorm1d(num_features,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)

大家好,又见面了,我是你们的朋友全栈君。

目录

 

说明

BatchNorm1d参数

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d训练时前向传播

BatchNorm1d评估时前向传播

总结


说明

网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval(),使网络里的BatchNorm模块的training=False。

BatchNorm1d参数

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

输入维度是(N, C, L)时,num_features应该取C;这里N是batch size,C是数据的channel,L是数据长度。

输入维度是(N, L)时,num_features应该取L;这里N是batch size,L是数据长度,这时可以认为每条数据只有一个channel,省略了C

eps

对输入数据进行归一化时加在分母上,防止除零,详情见下文。

momentum

更新全局均值running_mean和方差running_var时使用该值进行平滑,详情见下文。

affine

设为True时,BatchNorm层才会学习参数\gamma\beta,否则不包含这两个变量,变量名是weightbias,详情见下文。

track_running_stats

设为True时,BatchNorm层会统计全局均值running_mean和方差running_var,详情见下文。

BatchNorm1d训练时前向传播

  1. 首先对输入batch求E[x]Var[x],并用这两个结果把batch归一化,使其均值为0,方差为1。归一化公式用到了eps(\epsilon),即y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon }}。如下输入内容,shape是(3, 4),即batch_size=3,此时num_features需要传入4。
    tensor = torch.FloatTensor([[1, 2, 4, 1],
                                [6, 3, 2, 4],
                                [2, 4, 6, 1]])

    此时E[x]=[3, 3, 4, 2]Var[y]_{unbiased}=[7, 1, 4, 3](无偏样本方差)和Var[y]_{biased}=[4.6667, 0.6667, 2.6667, 2.0000](有偏样本方差),有偏和无偏的区别在于无偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用于更新running_var时,使用无偏样本方差即,但是在对batch进行归一化时,使用有偏样本方差,因此如果batch_size=1,会报错。归一化后的内容如下。

    [[-0.9258, -1.2247,  0.0000, -0.7071],
     [ 1.3887,  0.0000, -1.2247,  1.4142],
     [-0.4629,  1.2247,  1.2247, -0.7071]]
  2. 如果track_running_stats==True,则使用momentum更新模块内部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是x_{new}=(1-momentum)\times x_{cur}+momentum\times x_{batch},其中x_{new}代表更新后的running_meanrunning_varx_{cur}表示更新前的running_meanrunning_varx_{batch}表示当前batch的均值和无偏样本方差。
  3. 如果track_running_stats==False,则BatchNorm中不含有running_meanrunning_var两个变量。
  4. 如果affine==True,则对归一化后的batch进行仿射变换,即乘以模块内部的weight(初值是[1., 1., 1., 1.])然后加上模块内部的bias(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。
  5. 如果affine==False,则BatchNorm中不含有weightbias两个变量,什么都都不做。

BatchNorm1d评估时前向传播

  1. 如果track_running_stats==True,则对batch进行归一化,公式为y=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon }},注意这里的均值和方差是running_meanrunning_var,在网络训练时统计出来的全局均值和无偏样本方差。
  2. 如果track_running_stats==False,则对batch进行归一化,公式为y=\frac{x-{E}[x]}{\sqrt{​{Var}[x]+\epsilon }},注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_meanrunning_var注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。
  3. 如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。
  4. 如果affine==False,则BatchNorm中不含有weightbias两个变量,什么都不做。

总结

在使用batchNorm时,通常只需要指定num_features就可以了。网络训练前调用train(),训练时BatchNorm模块会统计全局running_meanrunning_var,学习weightbias,即文献中的\gamma\beta。网络评估前调用eval(),评估时,对传入的batch,使用统计的全局running_meanrunning_var对batch进行归一化,然后使用学习到的weightbias进行仿射变换。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141980.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 不同维度矩阵相乘[通俗易懂]

    不同维度矩阵相乘[通俗易懂]在深度学习中经常会遇到不同维度的矩阵相乘的情况,本文会通过一些例子来展示不同维度矩阵乘法的过程。总体原则:在高维矩阵中取与低维矩阵相同维度的分片来与低维矩阵相乘,结果再按分片时的顺序还原为高维矩阵。相乘结果的维度与原来的高维矩阵一致。二维乘一维三维乘一维三维乘二维…

    2025年6月18日
    0
  • 图像分割之分水岭算法[通俗易懂]

    图像分割之分水岭算法[通俗易懂]使用C++、opencv进行分水岭分割图像分水岭概念是以对图像进行三维可视化处理为基础的:其中两个是坐标,另一个是灰度级。基于“地形学”的这种解释,我们考虑三类点:a.属于局部性最小值的点,也可能存在一个最小值面,该平面内的都是最小值点b.当一滴水放在某点的位置上的时候,水一定会下落到一个单一的最小值点c.当水处在某个点的位置上时,水会等概率地流向不止一个这样的最小值点对一个特…

    2022年6月16日
    40
  • 指针的赋值和使用[通俗易懂]

    指针的赋值和使用[通俗易懂]更多来自:http://imcc.blogbus.com3.9.3指针的赋值和使用在得到一个指针变量之后,指针变量的值还是一个随机值。这个值可能是内存中无关紧要的数据,也可能是重要的数据或者程

    2022年7月4日
    27
  • 学生网上选课管理系统_选课管理系统

    学生网上选课管理系统_选课管理系统**数据库系统原理课程设计报告**学生选课管理系统(上)设计内容与要求:1、系统用户由三类组成:教师、学生和管理员。2、管理员负责的主要功能:①用户管理(老师、学生及管理员的增、删、改);②课程管理(添加、删除和修改);③选课管理(实现选课功能开放和禁止、老师成绩输入开放和禁止)。3、学生通过登录,可以查询课程的基本信息、实现选课、退课和成绩查询;4、老师通过登录,可以查看选…

    2022年10月16日
    0
  • jmeter基础教程_生活质量和生活品质有什么区别

    jmeter基础教程_生活质量和生活品质有什么区别前言:JMeter一个非常强大的测试工具,给大家简单的介绍一下基本使用方法入门篇,如若不懂,请重新学习小学语文,再来阅读,谢谢!!!1、第一步就安装JMeter,使用JMeter的前提是先把jdk等配置完成,才可以打开JMeter,不然会出现点开没反应的情况我这里展示的是一个改成中文的JMeter,英语好的小伙伴也可以不用改哈默认中文:在jmeter/bin/jmeter.properties在#language=en写入language=zh_CN默认查看结果处理展示编码为u.

    2022年10月21日
    0
  • vue 部署上线清除浏览器缓存「建议收藏」

    vue 部署上线清除浏览器缓存「建议收藏」vue项目打包上线之后,每一次都会有浏览器缓存问题,需要手动的清除缓存。这样用户体验非常不好,所以我们在打包部署的时候需要尽量避免浏览器的缓存。下面是我的解决方案:一、修改根目录index.html在head里面添加下面代码<metahttp-equiv=”pragram”content=”no-cache”><metahttp-equiv=”cache-control”content=”no-cache,no-store,must-revalidate”>

    2022年7月18日
    13

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号