【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数

【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数基本原理在卷积神经网络的卷积层之后总会添加BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定,BatchNorm2d()函数数学原理如下:BatchNorm2d()内部的参数如下:1.num_features:一般输…

大家好,又见面了,我是你们的朋友全栈君。

基本原理

在卷积神经网络的卷积层之后总会添加BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定,BatchNorm2d()函数数学原理如下:

                                                      【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数

BatchNorm2d()内部的参数如下:

1.num_features:一般输入参数为batch_size*num_features*height*width,即为其中特征的数量

2.eps:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5

3.momentum:一个用于运行过程中均值和方差的一个估计参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数)

4.affine:当设为true时,会给定可以学习的系数矩阵gamma和beta

上面的讲解还不够形象,我们具体通过如下的代码进行讲解:

代码演示

#encoding:utf-8
import torch
import torch.nn as nn
#num_features - num_features from an expected input of size:batch_size*num_features*height*width
#eps:default:1e-5 (公式中为数值稳定性加到分母上的值)
#momentum:动量参数,用于running_mean and running_var计算的值,default:0.1
m=nn.BatchNorm2d(2,affine=True) #affine参数设为True表示weight和bias将被使用
input=torch.randn(1,2,3,4)
output=m(input)

print(input)
print(m.weight)
print(m.bias)
print(output)
print(output.size())

具体的输出如下:

tensor([[[[ 1.4174, -1.9512, -0.4910, -0.5675],
          [ 1.2095,  1.0312,  0.8652, -0.1177],
          [-0.5964,  0.5000, -1.4704,  2.3610]],

         [[-0.8312, -0.8122, -0.3876,  0.1245],
          [ 0.5627, -0.1876, -1.6413, -1.8722],
          [-0.0636,  0.7284,  2.1816,  0.4933]]]])
Parameter containing:
tensor([0.2837, 0.1493], requires_grad=True)
Parameter containing:
tensor([0., 0.], requires_grad=True)
tensor([[[[ 0.2892, -0.4996, -0.1577, -0.1756],
          [ 0.2405,  0.1987,  0.1599, -0.0703],
          [-0.1824,  0.0743, -0.3871,  0.5101]],

         [[-0.0975, -0.0948, -0.0347,  0.0377],
          [ 0.0997, -0.0064, -0.2121, -0.2448],
          [ 0.0111,  0.1232,  0.3287,  0.0899]]]],
       grad_fn=<NativeBatchNormBackward>)
torch.Size([1, 2, 3, 4])

分析:输入是一个1*2*3*4 四维矩阵,gamma和beta为一维数组,是针对input[0][0],input[0][1]两个3*4的二维矩阵分别进行处理的,我们不妨将input[0][0]的按照上面介绍的基本公式来运算,看是否能对的上output[0][0]中的数据。首先我们将input[0][0]中的数据输出,并计算其中的均值和方差。

print("输入的第一个维度:")
print(input[0][0]) #这个数据是第一个3*4的二维数据
#求第一个维度的均值和方差
firstDimenMean=torch.Tensor.mean(input[0][0])
firstDimenVar=torch.Tensor.var(input[0][0],False)   #false表示贝塞尔校正不会被使用
print(m)
print('m.eps=',m.eps)
print(firstDimenMean)
print(firstDimenVar)

输出结果如下:

输入的第一个维度:
tensor([[ 1.4174, -1.9512, -0.4910, -0.5675],
        [ 1.2095,  1.0312,  0.8652, -0.1177],
        [-0.5964,  0.5000, -1.4704,  2.3610]])
BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
m.eps= 1e-05
tensor(0.1825)
tensor(1.4675)

我们可以通过计算器计算出均值和方差均正确计算。最后通过公式计算input[0][0][0][0]的值,代码如下:

batchnormone=((input[0][0][0][0]-firstDimenMean)/(torch.pow(firstDimenVar,0.5)+m.eps))\
    *m.weight[0]+m.bias[0]
print(batchnormone)

输出结果如下:

tensor(0.2892, grad_fn=<AddBackward0>)

结果值等于output[0][0][0][0]。ok,代码和公式完美的对应起来了。

ps:上面计算方差时有一个贝塞尔校正系数,具体可以通过如下链接参考:https://www.jianshu.com/p/8dbb2535407e 

从公式上理解即在计算方差时一般的计算方式如下:

                                                   【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数

通过贝塞尔校正的样本方差如下:

                                                   【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数

目的是在总体中选取样本时能够防止边缘数据不被选到。详细的理解可以参考上面的链接。

 

 

 

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/130225.html原文链接:https://javaforall.net

(1)
上一篇 2022年6月14日 下午6:16
下一篇 2022年6月14日 下午6:16


相关推荐

  • JavaScript算法笔试[通俗易懂]

    JavaScript算法笔试[通俗易懂]连续子数组的最大和翻转UL列表字符串去重

    2026年4月15日
    6
  • leetcode-150. 逆波兰表达式求值(栈)

    leetcode-150. 逆波兰表达式求值(栈)根据 逆波兰表示法,求表达式的值。有效的算符包括 +、-、*、/ 。每个运算对象可以是整数,也可以是另一个逆波兰表达式。说明:整数除法只保留整数部分。给定逆波兰表达式总是有效的。换句话说,表达式总会得出有效数值且不存在除数为 0 的情况。 示例 1:输入:tokens = [“2″,”1″,”+”,”3″,”*”]输出:9解释:该算式转化为常见的中缀算术表达式为:((2 + 1) * 3) = 9示例 2:输入:tokens = [“4″,”13″,”5″,”/”,”+”]输

    2022年8月11日
    7
  • OpenClaw“小龙虾”,创始人的必修课——趋势、耐心、风险与安全

    OpenClaw“小龙虾”,创始人的必修课——趋势、耐心、风险与安全

    2026年3月15日
    2
  • 浅谈SQL游标

    浅谈SQL游标
    游标(Cursor)是处理数据的一种方法,为了查看或者处理结果集中的数据,游标提供了在结果集中一次以行或者多行前进或向后浏览数据的能力。我们可以把游标当作一个指针,它可以指定结果中的任何位置,然后允许用户对指定位置的数据进行处理。游标允许你选择一组数据,通过翻阅这组数据记录——通常被称为数据集,检查每一个游标所在的特定的行。你可以将游标和局部变量组合在一起对每一个记录进行检查,当游标移动到下一个记录时,来执行一些外部操作。游标的另一个常见的用法是:保存查询结果以备以后使用。一个游标结果集是通过执

    2022年7月12日
    29
  • 给定一个n个正整数组成的数组_算法基础课acwing下载

    给定一个n个正整数组成的数组_算法基础课acwing下载给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:C l r d,表示把 A[l],A[l+1],…,A[r] 都加上 d。Q l r,表示询问数列中第 l∼r 个数的和。对于每个询问,输出一个整数表示答案。输入格式第一行两个整数 N,M。第二行 N 个整数 A[i]。接下来 M 行表示 M 条指令,每条指令的格式如题目描述所示。输出格式对于每个询问,输出一个整数表示答案。每个答案占一行。数据范围1≤N,M≤105,|d|≤10000,|A[i]|≤1

    2022年8月10日
    10
  • kotlin 初始化数组

    kotlin 初始化数组为什么 80 的码农都做不了架构师 gt gt gt

    2026年3月26日
    1

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号