RNN训练算法BPTT介绍

RNN训练算法BPTT介绍 本篇文章第一部分翻译自:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/,英文好的朋友可以直接看原文。最近看到RNN,先是困惑于怎样实现隐藏层的互联,搞明白之后又不太明白如何使用BPTT进…

大家好,又见面了,我是你们的朋友全栈君。

 

本篇文章第一部分翻译自:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/,英文好的朋友可以直接看原文。

最近看到RNN,先是困惑于怎样实现隐藏层的互联,搞明白之后又不太明白如何使用BPTT进行训练,在网上找资源发现本篇博客介绍较为详细易懂,自己翻译了一遍,以下:

RNN教程,第3部分,通过时间反向传播(BPTT)和梯度消失

这是RNN教程的第三部分。

在本教程的前面部分我们从头开始实现了一个RNN网络,但是没有探究实现BPTT计算梯度的细节。在这部分我们将给出BPTT的简要概述并且解释它和传统反向传播算法的区别。随后我们将致力于理解梯度消失问题(vanishing gradient problem),这个问题促成了LSTMs和GRUs的发展,在NLP(和其他领域),它们是当前最受欢迎和最为强大的模型中的两种。梯度消失问题最早于1991年有Sepp Hochreiter发现,最近由于深度结构的应用增多而重新受到关注。

如果想完全理解这部分内容,我建议你对偏导数和基本的反向传播工作很熟悉。如果你还不熟悉,你可以从【文中提供了三个地址】找到好的教程,它们随着难度的上升而排序。

Backpropagation Through TIme(BPTT)

我们先快速回顾一下RNN的等式。注意到这里有一个小变化,符号o变成了\widehat{y}。这是为了和我参考的一些文献保持一致。

RNN训练算法BPTT介绍

我们同时定义我们的损失函数(或者称为误差)为交叉熵损失,由以下公式给出:

RNN训练算法BPTT介绍

这里{y}_t是t时刻的正确单词,\widehat{y}_t是网络的预测。典型的,我们将完整的序列(句子)是为一个训练实例,所以总的误差是各个时间点(单词)误差的和。

RNN训练算法BPTT介绍

我们的目的是计算误差关于参数U,V和W的梯度并通过随机梯度下降(SGD)来学习好的参数。正如我们计算了误差的和,我们也将一个训练实例各个时间地啊你的梯度做一个求和:\frac{\partial E}{\partial W} = \sum\limits_{t} \frac{\partial E_t}{\partial W} 。

我们使用链式求导来计算这些导数。这是从误差开始后应用反向传播算法。在这篇文章的剩余部分我们将使用 E_3作为例子,这只是为了用一个实际的数来做推导。

\begin{aligned}  \frac{\partial E_3}{\partial V} &=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial V}\\  &=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial z_3}\frac{\partial z_3}{\partial V}\\  &=(\hat{y}_3 - y_3) \otimes s_3 \\  \end{aligned}

在上面的式子中,z_3 =Vs_3,同时\otimes表示两个向量的外积运算。如果上面讲的你跟不上也不用担心,我跳过了一些步骤,你可以自己尝试计算这些导数(这是一个很好的锻炼!)。我想从上面式子中得到的是\frac{\partial E_3}{\partial V}的计算仅仅依赖于当前时间点的数值\hat{y}_3, y_3, s_3。如果你掌握着这些,计算误差关于V的导数就仅仅是一个简单的矩阵乘法。

但是对于\frac{\partial E_3}{\partial W}(和U)的情况却是不同的。我们列出链式法则来一探究竟,与上面类似:

\begin{aligned}  \frac{\partial E_3}{\partial W} &= \frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial s_3}\frac{\partial s_3}{\partial W}\\  \end{aligned}

现在应该注意到的是s_3 = \tanh(Ux_t + Ws_2)依赖于s_2,而s_2又依赖于W和s_1,以此类推。如果我们计算关于W的导数我们不能简单地将s_2视为常量!我们需要再次使用链式法则,我们最终获得的表达式为:

\begin{aligned}  \frac{\partial E_3}{\partial W} &= \sum\limits_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial W}\\  \end{aligned}

我们将每个时间点对梯度的贡献求和。话句话说,由于在到达我们所关心的输出的过程中的每一步计算中都用了W,我们需要从t=3开始在网络中的每一个路径反向传播梯度直到t=0。

RNN训练算法BPTT介绍

注意到这和我们在深度前向神经网络中使用的标准反向传播算法是一样的。最主要的区别在于我们计算了关于W每个时间点上的梯度并将它们求和。传统神经网络中我们不会在层间分享参数,所以也不用做任何求和。但是在我看来BPTT不过是标准反向传播在没展开的RNN的一个有趣的名字。类似反向传播你可以定义一个向后传播的δ矢量,例如:\delta_2^{(3)} = \frac{\partial E_3}{\partial z_2} =\frac{\partial E_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial z_2},这里z_2 = Ux_2+ Ws_1。然后应用相同的方程式。

一个简单的BPTT实现类似于下面的代码:

RNN训练算法BPTT介绍

翻译结束,原文后续部分探讨梯度消失。

推到一下上面的公式:

RNN训练算法BPTT介绍

部分参考;

http://blog.sina.cn/dpool/blog/s/blog_6e32babb0102y3u7.html

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152328.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 浙江python信息技术教材_人工智能、Python…浙江省三到九年级将使用信息技术新修订教材…

    浙江python信息技术教材_人工智能、Python…浙江省三到九年级将使用信息技术新修订教材…浙江省教研室相关负责人表示,目前根据现行的高中教材,对小学、初中的老教材进行了修订,新教材将于今年9月投入使用,最新的线上教师培训也刚刚结束。扣哒世界作为全球最大的中小学人工智能和Python代码编程教学平台,已经从2019年开始系统支持浙江省中小学人工智能和Python代码编程师资培训。根据浙江省最新的教材目录,从小学三年级一直到九年级,内容都有不同程度的调整。三年级新增了“信息社会”和“网络生…

    2022年5月16日
    35
  • 王思聪新浪微博微博_kimi乔任梁王思聪

    王思聪新浪微博微博_kimi乔任梁王思聪作者|天使不投资人本文经授权转载自虎嗅APP(ID:huxiu_com)iG夺冠了!iG夺冠了!——11月3日,社交媒体成为了年轻人欢乐的海洋,微博尤甚。根本不知道LOL、也不知道iG是什么的叔叔阿姨们,对这次刷屏一点都不反感,毕竟IG老板,人称“校长”的王思聪,为了庆祝自家战队创造历史,在11月6日发起了一场豪气抽奖:从参与人数就可以隔着屏幕感受到一万元奖金的巨大…

    2022年8月30日
    3
  • mac clion激活码破解方法

    mac clion激活码破解方法,https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月15日
    411
  • 微信支付与支付宝钱包的竞争分析

    微信支付与支付宝钱包的竞争分析微信支付与支付宝钱包的竞争分析NO1:2013年8月,微信5.0上线,其中附加了一个功能叫做微信支付,当时的微信用户已经超过了4亿,活跃用户1.94亿,估计不少人在看微信支付同支付老大哥支付包的大战。说起微信支付和支付宝的大战,先来说说他们背景,微信支付是社交软件巨头腾讯公司旗下的微信中的附加功能,而支付宝是电商巨头阿里巴巴旗下的支付理财软件。两家都有超过2万的顶级互联网员工,兵强马壮…

    2022年5月14日
    58
  • excel 导出json_导出的数据格式不对

    excel 导出json_导出的数据格式不对json格式数据转Excel导出的两种方法第一种table格式数据直接转Excel:但是用这种方式会出现一种问题,就是当你的table有分页的情况下,只能抓取当前分页的数据。拿到表格的id就可以

    2022年8月4日
    20
  • java 静态变量 存储_java中,类的静态变量如果是对象,该对象将存储在内存的哪个区域?…

    java 静态变量 存储_java中,类的静态变量如果是对象,该对象将存储在内存的哪个区域?…静态变量所引用的实例位于Java堆或运行时常量池。Java字节码与Native机器码不同,字节码是运行在JVM这一平台上的,字节码在被解释的过程中,具体的执行方式因JVM的不同实现而不同,但是对于JVM来说,它的各种不同实现都必须要遵循Java虚拟机规范。JVM的运行时数据区包含以下部分:1、PC寄存器每一条Java虚拟机线程都有自己的PC寄存器,如果正在被线程执行的当前方法不是native的,那…

    2022年4月28日
    44

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号