LSTM详解 反向传播公式推导[通俗易懂]

LSTM详解 反向传播公式推导[通俗易懂]部分参考:http://colah.github.io/posts/2015-08-Understanding-LSTMs/1.结构1.1比较RNN:LSTM:其中的notation:1.2核心思想:LSTM在原来的隐藏层上增加了一个ThekeytoLSTMsisthecellstate,thehorizont…

大家好,又见面了,我是你们的朋友全栈君。

部分参考: http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1. 结构

1.1 比较

RNN:

LSTM详解 反向传播公式推导[通俗易懂]

LSTM:

LSTM详解 反向传播公式推导[通俗易懂]

其中的notation:

LSTM详解 反向传播公式推导[通俗易懂]

1.2 核心思想:

LSTM在原来的隐藏层上增加了一个

The key to LSTMs is the cell state, the horizontal line running through the top of the diagram.

LSTM详解 反向传播公式推导[通俗易懂]

LSTM详解 反向传播公式推导[通俗易懂]

  1. LSTM比RNN多了一个细胞状态,就是最上面一条线,像一个传送带,它让信息在这条线上传播而不改变信息。

    The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.

    LSTM详解 反向传播公式推导[通俗易懂]

  2. LSTM可以自己增加或移除信息,通过“门”的结构控制。

    “门”选择性地让信息是否通过,“门”包括一个sigmoid层和按元素乘。如下图:

    LSTM详解 反向传播公式推导[通俗易懂]

    sigmoid层输出0-1的值,表示让多少信息通过,1表示让所有的信息都通过。

    一个LSTM单元有3个门。

2. 流程

上面一条线 C t C_t Ct是细胞状态,下面的 h t h_t ht是隐藏状态 。

其实看多少个圆圈、黄框就知道有哪些操作了。

三个sigmoid层是三个门:忘记门、输入门、输出门。
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) C t ~ = t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) C t = f t ∗ C t − 1 + i t ∗ C t ~ o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) h t = o t ∗ t a n h ( C t ) f_t=\sigma(W_f \centerdot [h_{t-1},x_t]+b_f)\\ i_t=\sigma(W_i \centerdot [h_{t-1},x_t]+b_i)\\ \tilde{C_t}=tanh(W_C \centerdot [h_{t-1},x_t]+b_C)\\ C_t=f_t*C_{t-1}+i_t*\tilde{C_t}\\ o_t=\sigma(W_o \centerdot [h_{t-1},x_t]+b_o)\\ h_t=o_t*tanh(C_t) ft=σ(Wf[ht1,xt]+bf)it=σ(Wi[ht1,xt]+bi)Ct~=tanh(WC[ht1,xt]+bC)Ct=ftCt1+itCt~ot=σ(Wo[ht1,xt]+bo)ht=ottanh(Ct)
LSTM详解 反向传播公式推导[通俗易懂]

2.1 忘记门:扔掉信息(细胞状态)

第一步是决定从细胞状态里扔掉什么信息(也就是保留多少信息)。将上一步细胞状态中的信息选择性的遗忘 。

实现方式:通过sigmoid层实现的“忘记门”。以上一步的 h t − 1 h_{t-1} ht1和这一步的 x t x_t xt作为输入,然后为 C t − 1 C_{t-1} Ct1里的每个数字输出一个0-1间的值,表示保留多少信息(1代表完全保留,0表示完全舍弃),然后与 C t − 1 C_{t-1} Ct1乘。

例子:让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的类别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。
例如,他今天有事,所以我… 当处理到‘’我‘’的时候选择性的忘记前面的’他’,或者说减小这个词对后面词的作用。

LSTM详解 反向传播公式推导[通俗易懂]

2.2 输入层门:存储信息(细胞状态)

第二步是决定在细胞状态里存什么。将新的信息选择性的记录到细胞状态中。

实现方式:包含两部分,1. sigmoid层(输入门层)决定我们要更新什么值;2. tanh层创建一个候选值向量 C t ~ \tilde{C_t} Ct~,将会被增加到细胞状态中。 我们将会在下一步把这两个结合起来更新细胞状态。

例子:在我们语言模型的例子中,我们希望增加新的主语的类别到细胞状态中,来替代旧的需要忘记的主语。 例如:他今天有事,所以我… 当处理到‘’我‘’这个词的时候,就会把主语我更新到细胞中去。

LSTM详解 反向传播公式推导[通俗易懂]

2.3 更新细胞状态(细胞状态)

更新旧的细胞状态

实现方式: C t = f t ∗ C t − 1 + i t ∗ C t ~ C_t=f_t*C_{t-1}+i_t*\tilde{C_t} Ct=ftCt1+itCt~ f t f_t ft表示保留上一次的多少信息, i t i_t it表示更新哪些值, C t ~ \tilde{C_t} Ct~表示新的候选值。候选值被要更新多少(即 i t i_t it)放缩。

这一步我们真正实现了移除哪些旧的信息(比如一句话中上一句的主语),增加哪些新信息。

LSTM详解 反向传播公式推导[通俗易懂]

2.4 输出层门:输出(隐藏状态)

最后,我们要决定作出什么样的预测。

实现方式:1. 我们通过sigmoid层(输出层门)来决定输出新的细胞状态的哪些部分;2. 然后我们将细胞状态通过tanh层(使值在-1~1之间),然后与sigmoid层的输出相乘。

所以我们只输出我们想输出的部分。

例子:在语言模型的例子中,因为它就看到了一个 代词,可能需要输出与一个 动词 相关的信息。例如,可能输出是否代词是单数还是复数,这样如果是动词的话,我们也知道动词需要进行的词形变化。
例如:上面的例子,当处理到‘’我‘’这个词的时候,可以预测下一个词,是动词的可能性较大,而且是第一人称。 会把前面的信息保存到隐层中去。

LSTM详解 反向传播公式推导[通俗易懂]

3. 反向传播

LSTM详解 反向传播公式推导[通俗易懂]

如图上的5个不同的阶段反向传播误差。

1. 维度

先介绍下各个变量的维度, h t h_t ht 的维度是黄框里隐藏层神经元的个数,记为d,即矩阵 W ∗ W_* W 的行数(因为 W j i W_{ji} Wji是输入层的第 i i i个神经元指向隐藏层第 j j j个神经元的参数), x t x_t xt的维度记为n,则 [ h t − 1 x t ] [h_{t-1}\quad x_t] [ht1xt]的维度是 d + n d+n d+n,矩阵的维度都是 d ∗ ( d + n ) d*(d+n) d(d+n),其他的向量维度都是 d ∗ 1 d*1 d1。(为了表示、更新方便,我们将bias放到矩阵里)

W f W_f Wf举例:
LSTM详解 反向传播公式推导[通俗易懂]

同理:
LSTM详解 反向传播公式推导[通俗易懂]
合并为一个矩阵就是:
LSTM详解 反向传播公式推导[通俗易懂]

2.一些符号

⊙ \odot 是element-wise乘,即按元素乘。其他的为正常的矩阵乘。

为了表示向量和矩阵的元素,我们把时间写在上标。

δ z t \delta z^t δzt表示 E t E^t Et z t z^t zt的偏导。

⊗ ​ \otimes​ 表示外积,即列向量乘以行向量

3.两点疑惑

  1. 3.2中 δ C t + = δ h t ⊙ o t ⊙ [ 1 − t a n h 2 ( C t ) ] \delta C^t+=\delta h^t \odot o^t \odot [1-tanh^2(C^t)] δCt+=δhtot[1tanh2(Ct)] 我还没想明白
  2. 3.5中下划线的是 h t − 1 h^{t-1} ht1的函数 ,但 C i t − 1 C^{t-1}_i Cit1 也是 h t − 1 h^{t-1} ht1的函数,不知道为什么不算 。

3.1 δ h t \delta h^t δht

我们首先假设 ∂ E t ∂ h t = δ h t \frac{\partial E^t}{\partial h^t}=\delta h^t htEt=δht ,这里的 E t E^t Et指的是t时刻的误差,对每个时刻的误差都要计算一次。

LSTM详解 反向传播公式推导[通俗易懂]

3.2 δ o t 、 δ C t \delta o^t 、\delta C^t δotδCt

Forward: h t = o t ⊙ t a n h ( C t ) h^t=o^t \odot tanh(C^t) ht=ottanh(Ct)

已知: δ h t \delta h^t δht

求: δ o t 、 δ C t \delta o^t 、\delta C^t δotδCt

解:由于
LSTM详解 反向传播公式推导[通俗易懂]
所以
LSTM详解 反向传播公式推导[通俗易懂]

但是 这里提到 δ C t + = δ h t ⊙ o t ⊙ [ 1 − t a n h 2 ( C t ) ] \delta C^t+=\delta h^t \odot o^t \odot [1-tanh^2(C^t)] δCt+=δhtot[1tanh2(Ct)] 我还没想明白

3.3 δ f t 、 δ i t 、 δ C t ~ 、 δ C t − 1 \delta f_t、\delta i_t、\delta \tilde{C_t}、\delta C^{t-1} δftδitδCt~δCt1

Forward: C t = f t ⊙ C t − 1 + i t ⊙ C t ~ C^t=f^t \odot C^{t-1}+i^t \odot \tilde{C^t} Ct=ftCt1+itCt~

已知: δ o t 、 δ C t \delta o^t 、\delta C^t δotδCt

求: δ f t 、 δ i t 、 δ C t ~ 、 δ C t − 1 \delta f_t、\delta i_t、\delta \tilde{C_t}、\delta C^{t-1} δftδitδCt~δCt1

解:因为
LSTM详解 反向传播公式推导[通俗易懂]
所以:
LSTM详解 反向传播公式推导[通俗易懂]

3.4 δ W f 、 δ W i 、 δ W c 、 δ W o \delta W_f、 \delta W_i 、\delta W_c 、\delta W_o δWfδWiδWcδWo

Forward: f t = W f ⋅ s t , i t = W i ⋅ s t , o t = W o ⋅ s t , C t ~ = W C ⋅ s t f^t=W_f \cdot s^t ,i^t=W_i \cdot s^t ,o^t=W_o \cdot s^t ,\tilde{C^t}=W_C \cdot s^t \quad ft=Wfst,it=Wist,ot=Wost,Ct~=WCst

已知: δ f t 、 δ i t 、 δ C t ~ 、 δ C t − 1 \delta f_t、\delta i_t、\delta \tilde{C_t}、\delta C^{t-1} δftδitδCt~δCt1

求: δ W f 、 δ W i 、 δ W c 、 δ W o \delta W_f、 \delta W_i 、\delta W_c 、\delta W_o δWfδWiδWcδWo

解:由于
LSTM详解 反向传播公式推导[通俗易懂]

所以:
LSTM详解 反向传播公式推导[通俗易懂]

合并在一起就是:
LSTM详解 反向传播公式推导[通俗易懂]

3.5 δ h t − 1 \delta h_{t-1} δht1

Forward:
LSTM详解 反向传播公式推导[通俗易懂]
下划线的是 h t − 1 ​ h^{t-1}​ ht1的函数 ,但 C i t − 1 ​ C^{t-1}_i​ Cit1 也是 h t − 1 ​ h^{t-1}​ ht1的函数,不知道为什么不算 。

已知: δ f t 、 δ i t 、 δ C t ~ 、 δ C t − 1 \delta f_t、\delta i_t、\delta \tilde{C_t}、\delta C^{t-1} δftδitδCt~δCt1

求: δ h t − 1 \delta h_{t-1} δht1

解:
LSTM详解 反向传播公式推导[通俗易懂]
式1是Forward里第一个式子对 o i t 和 c i t o^t_i和c^t_i oitcit求导;式2是Forward里第二个式子对 f i t , i i t , C t ~ f^t_i,i^t_i,\tilde{C^t} fit,iit,Ct~求导。

然后看Forward里的后几个式子,对 h t − 1 h^{t-1} ht1求导:
LSTM详解 反向传播公式推导[通俗易懂]
所以(*)式用矩阵计算则为:
LSTM详解 反向传播公式推导[通俗易懂]

3.6 总梯度

∂ E ∂ W = ∑ t = 0 T ∂ E t ∂ W \frac{\partial E}{\partial W}=\sum^T_{t=0} \frac{\partial E^t}{\partial W} WE=t=0TWEt

4.变种

GRU

It combines the forget and input gates into a single “update gate.” It also merges the cell state and hidden state, and makes some other changes. The resulting model is simpler than standard LSTM models, and has been growing increasingly popular.

LSTM详解 反向传播公式推导[通俗易懂]

add peephole

LSTM详解 反向传播公式推导[通俗易懂]

use coupled forget

Another variation is to use coupled forget and input gates. Instead of separately deciding what to forget and what we should add new information to, we make those decisions together. We only forget when we’re going to input something in its place. We only input new values to the state when we forget something older.

LSTM详解 反向传播公式推导[通俗易懂]

其他

These are only a few of the most notable LSTM variants. There are lots of others, like Depth Gated RNNs by Yao, et al. (2015). There’s also some completely different approach to tackling long-term dependencies, like Clockwork RNNs by Koutnik, et al. (2014).

Which of these variants is best? Do the differences matter? Greff, et al. (2015) do a nice comparison of popular variants, finding that they’re all about the same. Jozefowicz, et al. (2015)tested more than ten thousand RNN architectures, finding some that worked better than LSTMs on certain tasks.

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/142951.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • RPC是什么? (学习笔记)

    RPC是什么? (学习笔记)什么是RPC?RPC全称RemoteProcedureCall,即远程过程调用,就是要像调用本地的函数一样去调远程函数,屏蔽远程调用的复杂性。为什么需要RPC?微服务、分布式应用的开发越来越

    2022年8月3日
    10
  • Html元素的scrollWidth和scrollHeight详解 .[通俗易懂]

    Html元素的scrollWidth和scrollHeight详解 .[通俗易懂]上网搜了一下scrollWidth和scrollHeight,大部分都是转帖,也没有具体说清楚,这两个属性值是什么,也没有图。索性自己测试一下,包含的浏览器有IE6,IE7,IE8,IE9,Firefox,Chrome,Opera,Safari,顺便把测试的截图也发上来,这样大家看着也明白。一、scrollWidth首先,我们先上MSDN上查一下scroll

    2022年7月23日
    14
  • IP地址范围怎么算_ip地址数目怎么算

    IP地址范围怎么算_ip地址数目怎么算1、如果掩码、IP等信息如下:2、我们可以看到,子网掩码为255.255.255.240,因为0-255有256个数字,所以256-240=16。也就是这个网段有16个IP地址。3、我们现在使用的IP地址是什么,或者是网关,最后的一个数字就好。IP是203,网关是193。4、找到IP段就能判断可用IP是多少。这时因为每个IP段都是由四部分组成,分别是网络号、网关、可用IP、广播号。…

    2022年10月19日
    5
  • 【Linux】面试题(2021最新版)

    【Linux】面试题(2021最新版)Linux的体系结构Linux-查找特定文件Linux-对日志内容做统计Linux-批量替换文件内容

    2022年6月3日
    37
  • Java就业前景和薪资状况,究竟怎么样呢?

    Java就业前景和薪资状况,究竟怎么样呢?在未来5年内,软件人才的需求将远大于供给。Java软件工程师是目前国际高端计算机领域就业薪资较高的一类软件工程师。看到这里有人问了:那Java的现实就业前景和薪资状况,究竟怎么样呢?1、Java工程师就业前景在美国、加拿大、澳大利亚、新加坡等发达国家和中等发达国家,Java软件工程师年薪均在4—15万美金,而在国内,Java软件工程师也有极好的工作机会和很高的薪水。一般情况下的Java软件工程师是分四个等级,从软件技术员到助理软件工程师,再到软件工程师,最后成为高级软件工程师。根据IDC的统计数字,

    2022年7月8日
    27
  • 计算机组成原理寄存器初始化,计算机组成原理寄存器实验

    计算机组成原理寄存器初始化,计算机组成原理寄存器实验

    2021年8月16日
    54

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号