详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。上面两幅图片,已经很详细的展示了损失是如何产生的,以及…

大家好,又见面了,我是你们的朋友全栈君。

上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。


首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

上面两幅图片,已经很详细的展示了损失是如何产生的, 以及如何来对参数求导,这是忽略细节的RNN反向传播流程,我相信已经描述的非常清晰了。下图(来自trask)描述RNN详细结构中反向传播的过程。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

有了清晰的反向传播的过程,我们接下来就需要进行理论的推到,由于符号较多,为了不至于混淆,根据下图,现标记符号如表格所示:

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

公式符号表
符号 含义

K

输入向量的大小(one-hot长度,也是词典大小)

T

输入的每一个序列的长度

H

隐藏层神经元的个数

X=\left \{ x_{1},x_{2},x_{3}....,x_{T} \right \}

样本集合

x_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻的输入

y_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻经过Softmax层的输出。

\hat{y}_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻输入样本的真实标签

L_{t}

t时刻的损失函数,使用交叉熵函数,

L_t=-\hat{y}_t^Tlog(y_t)

L

序列对应的损失函数:

L=\sum\limits_t^T L_t

RNN的反向传播是每处理完一个样本就需要对参数进行更新,因此当执行完一个序列之后,总的损失函数就是各个时刻所得的损失之和。

s_{t}\epsilon \mathbb{R}^{H\times 1}

t个时刻RNN隐藏层的输入。

h_{t}\epsilon \mathbb{R}^{H\times 1}

第t个时刻RNN隐藏层的输出。

z_{t}\epsilon \mathbb{R}^{H\times 1}

输出层的输入,即Softmax函数的输入

W\epsilon \mathbb{R}^{H\times K}

输入层与隐藏层之间的权重。

U\epsilon \mathbb{R}^{H\times H}

上一个时刻的隐藏层 与 当前时刻隐藏层之间的权值。

V\epsilon \mathbb{R}^{K\times H}

隐藏层与输出层之间的权重。

                                                                     \begin{matrix} \: \: \: \: \: \: \: \: \; \; \; \; \; \; \; \; \; \; \; \; \; s_t=Uh_{t-1}+Wx_t+b\\ \\ h_t=\sigma(s_t)\\ \\ \; \; \; \; z_t=Vh_t+c\\ \\ \; \; \; \; \; \; \; \; \; \; y_t=\mathrm{softmax}(z_t) \end{matrix}

我们对参数V,c求导比较方便,只有每一时刻的输出对应的损失与V,c相关,可以直接进行求导,即:

                                                  \frac{\partial L}{\partial V} =\sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial V} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial V} = \sum\limits_{t=1}^{\tau}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                 \frac{\partial L}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial c} = \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}

要对参数W,U,b进行更新,就不那么容易了,因为参数W,U,b虽是共享的,但是他们不只是对第t刻的输出做出了贡献,同样对t+1时刻隐藏层的输入s_{t+1}做出了贡献,因此在对W,U,b参数求导的时候,需要从后向前一步一步求导。

假设我们在对t时刻的参数W,U,b求导,我们利用链式法则可得出:

                                                                   \frac{\partial L}{\partial W}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial W}

                                                                    \frac{\partial L}{\partial U}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial U}

                                                                    \frac{\partial L}{\partial b}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial b}

我们发现对W,U,b进行求导的时候,都需要先求出\frac{\partial L}{\partial h_{t}},因此我们设:

                                                                         \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}

那么我们现在需要先求出\delta ^{t},则:

                                                                        \frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}=V^{T}(y_{t}-\hat{y}_{t})

                                                                   \begin{matrix} \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}=U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\delta ^{t+1}) \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\sigma ^{'}(z_{t+1}))\delta ^{t+1} \\ \\ =U^{T}diag(1-h_{t+1}^{2})\delta ^{t+1} \end{matrix}

注:在求解激活函数导数时,是将已知的部分求导之后,然后将它和激活函数导数部分进行哈达马乘积。激活函数的导数一般是和前面的进行哈达马乘积,这里的激活函数是双曲正切,用矩阵中对角线元素表示向量中各个值的导数,可以去掉哈达马乘积,转化为矩阵乘法。

则:

                                                                \begin{matrix} \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}\\ \\\: \; \; \; \; \; \; \; \; \; \; \; \; \; \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ \: \: \: \: \: \: \: \: \: \: \: \: \: \: \: \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1} (1-h_{t+1}^{2}) \end{matrix}

我们求得\delta ^{t},之后,便可以回到最初对参数的求导,因此有:

                                                           \frac{\partial L}{\partial W} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial W} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                            \frac{\partial L}{\partial b}= \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial b} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                           \frac{\partial L}{\partial U} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial U} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

有了各个参数导数之后,我们可以进行参数更新:

                                                          W^{'}=W-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                           U^{'}=U-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

                                                          V^{'}=V-\theta \sum\limits_{t=1}^{T}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                           b^{'}=b-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                          c^{'}=c- \theta \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}


参考:

刘建平《循环神经网络(RNN)模型与前向反向传播算法

李弘毅老师《深度学习》

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152338.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • log4j-over-slf4j与slf4j-log4j12共存stack overflow异常分析

    log4j-over-slf4j与slf4j-log4j12共存stack overflow异常分析log4j over slf4j 和 slf4j log4j12 是跟 java 日志系统相关的两个 jar 包 当它们同时出现在 classpath 下时 就可能会引起堆栈溢出异常 先大致梳理了一下现有 Java 日志体系接口 然后仔细分析了这种异常出现的原因 最后重现异常并展示了详细的调用过程

    2026年1月18日
    1
  • Linux下安装Tomcat 10

    Linux下安装Tomcat 10(Linux)Deepin下安装Tomcat101)在官网下载tar.gz包2)解压到目录(这里‘用户名’换成你自己的)sudotar-zxvf/home/用户名/Downloads/apache-tomcat-10.0.0.tar.gz-C/usr/local3)重命名sudomv/usr/local/apache-tomcat-10.0.0/usr/local/Tomcat4)测试sudo/usr/local/Tomcat/bin/startup.sh若失败则

    2022年6月2日
    45
  • 电路板维修入门教程视频_电路板坏了去哪里维修

    电路板维修入门教程视频_电路板坏了去哪里维修(一)          电容篇  1、电容在电路中一般用“C”加数字表示(如C25表示编号为25的电容)。电容是由两片金属膜紧靠,中间用绝缘材料隔开而组成的元件。电容的特性主要是隔直流通交流。电容容量的大小就是表示能贮存电能的大小,电容对交流信号的阻碍作用称为容抗,它与交流信号的频率和电容量有关。容抗XC=1/2πfc(f表示交流信号的

    2022年8月29日
    5
  • Pytest(1)安装与入门[通俗易懂]

    Pytest(1)安装与入门[通俗易懂]pytest介绍pytest是python的一种单元测试框架,与python自带的unittest测试框架类似,但是比unittest框架使用起来更简洁,效率更高。根据pytest的官方网站介绍,它

    2022年7月28日
    7
  • Navicat Premium 中文版注册码

    Navicat Premium 中文版注册码NAVN-U6QE-6PX7-44K5NAVI-WVK6-ZYW4-LQYUNAVJ-5DOO-FCAA-PHMZ经测试,Nacicat版本是10.0.11(黄色版本)可以使用第一个注册码关注我的技术公众号,每天都有优质技术文章推送。微信扫一扫下方二维码即可关注:…

    2022年9月25日
    2
  • Java 并发专题 : Timer的缺陷 用ScheduledExecutorService替代「建议收藏」

    Java 并发专题 : Timer的缺陷 用ScheduledExecutorService替代「建议收藏」继续并发,上篇博客对于ScheduledThreadPoolExecutor没有进行介绍,说过会和Timer一直单独写一篇Blog.1、Timer管理延时任务的缺陷a、以前在项目中也经常使用定时器,比如每隔一段时间清理项目中的一些垃圾文件,每个一段时间进行数据清洗;然而Timer是存在一些缺陷的,因为Timer在执行定时任务时只会创建一个线程,所以如果存在多个任务,且任务时间过长,超过了两

    2022年6月2日
    49

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号