详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。上面两幅图片,已经很详细的展示了损失是如何产生的,以及…

大家好,又见面了,我是你们的朋友全栈君。

上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。


首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

上面两幅图片,已经很详细的展示了损失是如何产生的, 以及如何来对参数求导,这是忽略细节的RNN反向传播流程,我相信已经描述的非常清晰了。下图(来自trask)描述RNN详细结构中反向传播的过程。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

有了清晰的反向传播的过程,我们接下来就需要进行理论的推到,由于符号较多,为了不至于混淆,根据下图,现标记符号如表格所示:

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

公式符号表
符号 含义

K

输入向量的大小(one-hot长度,也是词典大小)

T

输入的每一个序列的长度

H

隐藏层神经元的个数

X=\left \{ x_{1},x_{2},x_{3}....,x_{T} \right \}

样本集合

x_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻的输入

y_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻经过Softmax层的输出。

\hat{y}_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻输入样本的真实标签

L_{t}

t时刻的损失函数,使用交叉熵函数,

L_t=-\hat{y}_t^Tlog(y_t)

L

序列对应的损失函数:

L=\sum\limits_t^T L_t

RNN的反向传播是每处理完一个样本就需要对参数进行更新,因此当执行完一个序列之后,总的损失函数就是各个时刻所得的损失之和。

s_{t}\epsilon \mathbb{R}^{H\times 1}

t个时刻RNN隐藏层的输入。

h_{t}\epsilon \mathbb{R}^{H\times 1}

第t个时刻RNN隐藏层的输出。

z_{t}\epsilon \mathbb{R}^{H\times 1}

输出层的输入,即Softmax函数的输入

W\epsilon \mathbb{R}^{H\times K}

输入层与隐藏层之间的权重。

U\epsilon \mathbb{R}^{H\times H}

上一个时刻的隐藏层 与 当前时刻隐藏层之间的权值。

V\epsilon \mathbb{R}^{K\times H}

隐藏层与输出层之间的权重。

                                                                     \begin{matrix} \: \: \: \: \: \: \: \: \; \; \; \; \; \; \; \; \; \; \; \; \; s_t=Uh_{t-1}+Wx_t+b\\ \\ h_t=\sigma(s_t)\\ \\ \; \; \; \; z_t=Vh_t+c\\ \\ \; \; \; \; \; \; \; \; \; \; y_t=\mathrm{softmax}(z_t) \end{matrix}

我们对参数V,c求导比较方便,只有每一时刻的输出对应的损失与V,c相关,可以直接进行求导,即:

                                                  \frac{\partial L}{\partial V} =\sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial V} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial V} = \sum\limits_{t=1}^{\tau}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                 \frac{\partial L}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial c} = \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}

要对参数W,U,b进行更新,就不那么容易了,因为参数W,U,b虽是共享的,但是他们不只是对第t刻的输出做出了贡献,同样对t+1时刻隐藏层的输入s_{t+1}做出了贡献,因此在对W,U,b参数求导的时候,需要从后向前一步一步求导。

假设我们在对t时刻的参数W,U,b求导,我们利用链式法则可得出:

                                                                   \frac{\partial L}{\partial W}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial W}

                                                                    \frac{\partial L}{\partial U}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial U}

                                                                    \frac{\partial L}{\partial b}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial b}

我们发现对W,U,b进行求导的时候,都需要先求出\frac{\partial L}{\partial h_{t}},因此我们设:

                                                                         \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}

那么我们现在需要先求出\delta ^{t},则:

                                                                        \frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}=V^{T}(y_{t}-\hat{y}_{t})

                                                                   \begin{matrix} \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}=U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\delta ^{t+1}) \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\sigma ^{'}(z_{t+1}))\delta ^{t+1} \\ \\ =U^{T}diag(1-h_{t+1}^{2})\delta ^{t+1} \end{matrix}

注:在求解激活函数导数时,是将已知的部分求导之后,然后将它和激活函数导数部分进行哈达马乘积。激活函数的导数一般是和前面的进行哈达马乘积,这里的激活函数是双曲正切,用矩阵中对角线元素表示向量中各个值的导数,可以去掉哈达马乘积,转化为矩阵乘法。

则:

                                                                \begin{matrix} \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}\\ \\\: \; \; \; \; \; \; \; \; \; \; \; \; \; \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ \: \: \: \: \: \: \: \: \: \: \: \: \: \: \: \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1} (1-h_{t+1}^{2}) \end{matrix}

我们求得\delta ^{t},之后,便可以回到最初对参数的求导,因此有:

                                                           \frac{\partial L}{\partial W} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial W} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                            \frac{\partial L}{\partial b}= \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial b} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                           \frac{\partial L}{\partial U} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial U} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

有了各个参数导数之后,我们可以进行参数更新:

                                                          W^{'}=W-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                           U^{'}=U-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

                                                          V^{'}=V-\theta \sum\limits_{t=1}^{T}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                           b^{'}=b-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                          c^{'}=c- \theta \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}


参考:

刘建平《循环神经网络(RNN)模型与前向反向传播算法

李弘毅老师《深度学习》

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152338.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • redis默认端口为什么是6379_redis 端口

    redis默认端口为什么是6379_redis 端口1、如果开了redis服务,先将服务关闭2、在window上找到redis的安装目录,修改redis.windows.conf文件,在里面将默认端口改为你想要的端口号3、将redis.windows.conf文件直接拖入redis-server.exe,弹出窗口

    2022年9月18日
    3
  • word在试图打开文件时遇到错误[通俗易懂]

    word在试图打开文件时遇到错误[通俗易懂]昨天晚上在台式机上写了个word文档,发到邮箱。今天在邮箱上下载后就打不开了,出现“word在试图打开文件时遇到错误”。让人很郁闷,因为两台电脑上装的都是word2016.。。。看到网上说的是用WPS打开,之后再另存为,可是我早已没有了WPS。突然灵光一现,我点击了该文档的属性:有个解除锁定,我在上面划勾,点击确定后,再次打开,竟然就打开了~~~~开心~20170825周

    2022年4月28日
    81
  • webservice安全策略[通俗易懂]

    webservice安全策略[通俗易懂]前些日子公司的应用要和合作方对接,我参与了webservice这块的工作,在访问量很小的情况下基本上完成了功能,但安全这块没有找到合适的方案,所以自己做了些旁门左道的设想,不一定合理和完善,希望能起个

    2022年7月2日
    32
  • HDU 4825 Xor Sum 字典树+位运算

    HDU 4825 Xor Sum 字典树+位运算

    2021年12月5日
    52
  • 关于大数据,云计算,物联网的概述正确的是_物联网应用领域

    关于大数据,云计算,物联网的概述正确的是_物联网应用领域1、大数据时代  以大数据、物联网和云计算为标志的第三次信息化浪潮开始,大数据时代全面开启。大数据发展主要经历了三个历程。2、大数据的概念  关于什么是大数据”这个问题,大家比较认可关于大数据的“4V”说法。大数据的4个“V”,或者说是大数据的4个特点,包含4个层面:数据量大(Volume).数据类型繁多(Variety).处理速度快(Velocity)和价值密度低(Value)。3、…

    2022年10月7日
    2
  • JSF之经常使用注解

    JSF之经常使用注解

    2022年1月22日
    47

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号