RNN-bptt简单推导「建议收藏」

RNN-bptt简单推导「建议收藏」摘要:在前面的文章里面,RNN训练与BP算法,我们提到了RNN的训练算法。但是回头看的时候在时间的维度上没有做处理,所以整个推导可能存在一点问题。那么,在这篇文章里面,我们将介绍bptt(BackPropagationThroughTime)算法如在训练RNN。关于bptt这里首先解释一下所谓的bptt,bptt的思路其实很简单,就是把整个RNN按时间的维度展

大家好,又见面了,我是你们的朋友全栈君。

摘要:


在前面的文章里面,RNN训练与BP算法,我们提到了RNN的训练算法。但是回头看的时候在时间的维度上没有做处理,所以整个推导可能存在一点问题。

那么,在这篇文章里面,我们将介绍bptt(Back Propagation Through Time)算法如在训练RNN。

关于bptt


这里首先解释一下所谓的bptt,bptt的思路其实很简单,就是把整个RNN按时间的维度展开成一个“多层的神经网络”。具体来说比如下图:
这里写图片描述

既然RNN已经按时间的维度展开成一个看起来像多层的神经网络,这个时候用普通的bp算法就可以同样的计算,只不过这里比较复杂的是权重共享。比如上图中每一根线就是一个权重,而我们可以看到在RNN由于权重是共享的,所以三条红线的权重是一样的,这在运用链式法则的时候稍微比较复杂。

正文:


首先,和以往一样,我们先做一些定义。
hti=f(netthi)

netthi=m(vimxtm)+s(uisht1s)

nettyk=mwkmhtm
最后一层经过softmax的转化
otk=enettykkenettyk
在这里我们使用交叉熵作为Loss Function
Et=kztklnotk

我们的任务同样也是求 Ewkm Evim Euim
注意,这里的 E 没有时间的下标。因为在RNN里,这些梯度分别为各个时刻的梯度之和。
即:


Ewkm=stept=0Etwkm

Evim=stept=0Etvim
Euim=stept=0Etuim

所以下面我们推导的是 Etwkm Etvim Etuim

我们先推导 Etwkm
Etwkm=kEtotkotknettyknettykwkm=(otkztk)htm 。(这一部分的推导在前面的文章已经讨论过了)。
在这里,记误差信号:
δ(output,t)k=Etnettyk=kEtotkotknettyk=(otkztk) (后面会用到)

对于 Etvim Etuim 其实是差不多的,所以这里详细介绍其中一个。这两个导数也是RNN里面最复杂的。

推导: Etvim

Etvim=tt=0Etnetthinetthivim
对于这个式子第一次看可能有点懵逼,这里稍微解释一下:
从式: hti=f(m(vimxtm)+s(uisht1s)) 中我们可以看到, vim 影响的是所有时刻的 netthi,t=0,1,2,....step 。所以当 Et vim 求偏导的时候,由于链式法则需要考虑到所有时刻的 netthi

下面分成两部分来求 Etnetthi netthivim.
第一部分: Etnetthi
这里我们记 δ(t,t)i=Etnetthi (误差信号,和前面文章一样)。



(由于带着符号去求这两个导数会让人看起来非常懵逼,所以下面指定具体的值,后面抽象给出通式)
假设共3个时刻,即t=0,1,2。
对于 t=2 t=2 时:
E2 表示第2个时刻(也是最后一个时刻)的误差)
net2hi 表示第2个时刻隐藏层第i个神经元的净输入)
具体来说: E2net2hi=E2h2ih2inet2hi

对于 E2h2i=kE2net2yknet2ykh2i
由于 δ(output,t)k=Etnettyk
所以,我们有:
E2h2i=kE2net2yknet2ykh2i=kδ(output,2)knet2ykh2i=kδ(output,2)kwki
综上:
δ(2,2)i=E2net2hi=E2h2ih2inet2hi=(kδ(output,2)kwki)f(net2hi)

对于 t=1 t=2 时:
E2 表示第2个时刻的误差)
net1hi 表示第1个时刻隐藏层第i个神经元的净输入)
具体来说: E2net1hi=E2h1ih1inet1hi
那么 E2h1i=kE2net1yknet1ykh1i+jE2net2hjnet2hjh1i 。请对比这个式子和上面 t=2 t=2 时的区别,区别在于多了一项 jE2net2hjnet2hjh1i 。这个原因我们已经在RNN与bp算法中讨论过,这里简单的说就是由于 t=1 时刻有 t=2 时刻反向传播回来的误差,所以要考虑上这一项,但是对于 t=2 已经是最后一个时刻了,没有反向传播回来的误差。

对于第一项 kE2net1yknet1ykh1i 其实是0。下面简单分析下原因:
上式进一步可以化为: k(kE2o1ko1knet1yk)net1ykh1i E2 与第1个时刻输出 o1k 无关。所以为0。

对于第二项 jE2net2hjnet2hjh1i ,我们带入 δ(t,t)i=Etnetthi 有:
jE2net2hjnet2hjh1i=jδ(2,2)jnet2hjh1i
同时明显有 net2hjh1i=uji
即: E2h1i=jδ(2,2)juji

综上:
δ(1,2)i=E2net1hi=E2h1ih1inet1hi=(jδ(2,2)jnet2hjh1i)f(net1hi)=(jδ(2,2)juji)f(net1hi)

对于 t=0 t=2 时:
E2 表示第2个时刻的误差)
net0hi 表示第0个时刻隐藏层第i个神经元的净输入)。
和上面的思路一样,我们容易得到:
δ(0,2)i=E2net0hi=(jδ(1,2)juji)f(net0hi)

至此,我们求完了 Etnetthi 。下面我们来总结一下其通式:

Etnetthi=δ(t,t)i={
(kδ(output,t)kwki)f(netthi),(jδ(t+1,t)juji)f(netthi),t=ttt


另外,对于 δ(output,t)k 有以下表达式:
δ(output,t)k=Etnettyk=kEtotkotknettyk=(otkztk)



最后只要求出 netthivim ,其值具体为 netthivim=xtm


最后,对于 Etuim 其实和上面的差不多,主要是后面的部分不一样,具体来说:
Etuim=tt=0Etnetthinetthiuim ,可以看到就只有等式右边的第二项不一样,关键部分是一样的。 netthiuim=ht1m

细节-1


上面提到,当只有3个时刻时,t=0,1,2。
对于误差 E2 (最后一个时刻的误差),没有再下一个时刻反向传回的误差。
那么对于 E1 (第1个时刻的误差)存在下一个时刻反向传回的误差,但是在 E1h1i 中的第二项 jE1net2hjnet2hjh1i 仍然为0。是因为 E1net2hj=0 ,因为 E1 的误差和下一个时刻隐藏层的输出没有任何关系。

总结


看起来bptt和我们之前讨论的bp本质上是一样的,只是在一些细节的处理上由于权重共享的原因有所不同,但是基本上还是一样的。

下面这篇文章是有一个简单的rnn代码,大家可以参考一下
参考文章1
代码的bptt中每一步的迭代公式其实就是上面的公式。希望对大家有帮助~

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152320.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • MyEclipse10.6 安装SVN插件方法及插件下载地址

    MyEclipse10.6 安装SVN插件方法及插件下载地址方法一:在线安装1.打开HELP->MyEclipseConfigurationCenter。切换到SoftWare标签页。 2.点击AddSite打开对话框,在对话框Name输入Svn,URL中输入:http://subclipse.tigris.org/update_1.6.x3.在左边栏中找到PersonalSite中找到SVN展开。将Core

    2022年7月20日
    15
  • ARM64架构、国产系统UOS、银河麒麟离线安装jdk1.7、jdk1.8,jdk7、jdk8离线安装(100%成功)

    ARM64架构、国产系统UOS、银河麒麟离线安装jdk1.7、jdk1.8,jdk7、jdk8离线安装(100%成功)Linuxarm64架构下安装jdk1.7、jdk1.8说明:理论上适用于arm64架构的Linux系统,目前在银河麒麟、UOS测试可安装通过1.挂载ISO介质上传Kylin-4.0.2-FT2000Plus.iso到服务器到/opt/目录下,(如果没有该介质,请向笔者索要,网盘下载)创建挂载目录mkdir/mnt/apt挂载isomount/opt/Kylin-4.0.2-FT2000Plus.iso/mnt/apt2.修改本地源先备份本地源cp/et

    2022年6月3日
    312
  • 如何删除sqlserver实例_sql server删除表

    如何删除sqlserver实例_sql server删除表在网上找到下面几种方法,本人使用的是第一种,很实用。1.删除SQLServer的特定实例若要删除SQLServer的某个特定实例,请按照以下步骤操作:找到并删除%drive%:\\ProgramFiles\\MicrosoftSQLServer\\MSSQL\\Binn文件夹,其中%drive%是要删除的SQLServer实例的位置。找到以下注册表项:HKEY…

    2022年10月2日
    2
  • document.onreadystatechange_js转json格式

    document.onreadystatechange_js转json格式标准参考无。问题描述onreadystatechange事件通常用在基于XMLHttpRequest对象的AJAX应用中,当的该对象的loadstate改变时,会触发此事件。但在IE

    2022年8月2日
    6
  • string对象下标越界

    string对象下标越界#include<iostream>#include<string>usingnamespacestd;intmain(){stringa;cin>>a[0];cin>>a[1];return0;}最近写代码时发生了这一问题,就是上边的程序,运行后会出现数组越界。其实这是一个非常小的问题,原因是我自己把string当成了一个无穷大的数组,string可以无穷大,但是这并不能将他当成无穷大数组.

    2022年9月26日
    0
  • raft算法详解_python raft

    raft算法详解_python raft  raft是工程上使用较为广泛的强一致性、去中心化、高可用的分布式协议。在这里强调了是在工程上,因为在学术理论界,最耀眼的还是大名鼎鼎的Paxos。但Paxos是:少数真正理解的

    2022年8月4日
    10

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号