RNN学习笔记(一)-简介及BPTT RTRL及Hybrid(FP/BPTT)算法[通俗易懂]

RNN学习笔记(一)-简介及BPTT RTRL及Hybrid(FP/BPTT)算法[通俗易懂]RNN网络的学习算法-BPTT笔记本Markdown编辑器使用StackEdit修改而来,用它写博客,将会带来全新的体验哦:Markdown和扩展Markdown简洁的语法代码块高亮图片链接和图片上传LaTex数学公式UML序列图和流程图离线写博客导入导出Markdown文件丰富的快捷键快捷键加粗Ctrl+B斜体Ctrl+I引

大家好,又见面了,我是你们的朋友全栈君。

RNN学习笔记(一)-简介及BPTT RTRL及Hybrid(FP/BPTT)算法

本文假设读者已经熟悉了常规的神经网络,并且了解了BP算法,如果还不了解的,参见UFIDL的教程。
1.RNN结构
2.符号定义
3.网络unrolled及公式推导
4.BPTT
5.RTRL
6.Hybrid(FP/BPTT)
7.参考文献


1.RNN结构

如下图1是一个最简单的RNN:
图1
其中集合 I

m
个外部输入节点,左下角的 U 为前一时刻的隐层输出节点,U中的节点数为

n
,并假定U中所有节点的输出都参与到下一时刻的输入。

2.符号定义

定义:
xi(t) : t 时刻第

i
个输入节点的输出值,且 iIU
sk(t) : t 时刻第

k
个隐层节点的输出值,且 kU
yk(t) : t 时刻第

k
个输出层节点的输出值,且 kU
dk(t) : t 时刻隐层第

k
个节点的期望输出(即训练数据)
wli :第 i 个输入到第

l
个隐层节点的权重,其中 iIlU
wlk :第 k 个输入到第

l
个隐层节点的权重,其中 klU
τ :假定网络的起始时刻为 t0 ,当前时刻为 t

t[t0,t)
, τ(t,t]
yk(τ) : τ 时刻第 k 个输出节点的输出值,且

kU,τ(t0,t]
,对于所有的 τ 而言,其实有 yk(τ)=yk(τ) ,这里之所以引入新符号,是为了避免求导运算时混淆1

再来是一组等式定义:
sk(τ+1)=wx(τ)
ek(t)=dk(t)yk(t)
J(τ)=kUek(t)
Jtotal(t,t)=τ=t+1tJ(τ),t[t0,t)
ϵk(τ;F)=Fyk(τ)
ek(τ;F)=Fyk(τ)
δk(τ;F)=Fsk(τ)
pkij(τ)=yk(τ)wij
因为假定 F 只与

yk(τ),τ(t,t]
显式相关,所以,当 τt 时, ek(τ;F)=0
由于 F 是任意与

yk(t)
相关的函数,实际应用中,可以取
F=J(τ)F=Jtotal(t,t) 或其它函数。
因为初始状态的输出 yk(t0) 为预设值,与 w 之间不存在函数关系,所以当

τ=t0
时, pkij(t0)=0

3.网络unrolled及公式推导

将网络按时间展开:
这里写图片描述
根据上图,下面两个式子成立:
sk(t+1)=lUwklyl(t)+lIwklxnetl(t)=lUIwklxl(t)......(2)
yk(t+1)=fk(sk(t+1))......(3)

显然, yk(τ+1),yk(τ+2),...,yk(t) 可以表示成 s(τ+1) 的函数,因此,
F=F(y(t),y(t+1),...,yk(τ),s(τ+1))=F
下面对公式进行进一步的推导:
ϵk(τ;F)=Fyk(τ)
=F(y(t),y(t+1),...,yk(τ),s(τ+1))yk(τ)
由复合函数求导法则,上式可进一步变为:
Fy(t)y(t)yk(τ)+Fy(t+1)y(t+1)yk(τ)+...+Fy(τ)y(τ)yk(τ)+Fs(τ+1)s(τ+1)yk(τ)

τ<τ 时,显然 y(τ) y(τ) 无关,故上式的前半部分为0,即:
ϵk(τ;F)=Fy(τ)y(τ)yk(τ)+Fs(τ+1)s(τ+1)yk(τ)

这里:
Fy(τ)=Fy1(τ)Fy2(τ)...Fyk(τ)...Fyn(τ)

y(τ)yk(τ)=y1(τ)yk(τ)y2(τ)yk(τ)...yk(τ)yk(τ)...yn(τ)yk(τ)=00...1...0

Fs(τ+1)=Fs1(τ+1)Fs2(τ+1)...Fsl(τ+1)...Fsn(τ+1)=δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)

s(τ+1)yk(τ)=s1(τ+1)yk(τ)s2(τ+1)yk(τ)...sl(τ+1)yk(τ)...sn(τ+1)yk(τ)=w1kw2k...wlk...wnk

代入,上式可以变为:
ϵk(τ;F)=Fy1(τ)Fy2(τ)...Fyk(τ)...Fyn(τ)T00...1...0+δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)Tw1kw2k...wlk...wnk=Fyk(τ)+lUwlkδl(τ+1;F)

所以就有:
ϵk(τ;F)=Fyk(τ)+lUwlkδl(τ+1;F)=ek(τ;F)+lUwlkδl(τ+1;F)

因为当 τ=t 时, ϵk(t;F)=ek(t;F) ,所以有:

ϵk(τ;F)=ek(t;F)  if  τ=tek(τ;F)+lUwlkδl(τ+1;F)  if  τ<t

δk(τ;F)=Fsk(τ)=Fyk(τ)yk(τ)sk(τ)=ϵk(τ;F)fk(sk(τ))

进一步推导:
ϵk(τ;F)=(ek(τ;F)+lUwlkδl(τ+1;F))fk(sk(τ))
先做如下定义:
wij :第 j 个输入到第

i
个隐层节点的权重(迭代更新之前),其中 iU,jUI
wij(τ) : τ 时刻第 j 个输入到第

i
个隐层节点的权重(迭代更新之前),其中 τ[t0,t),iU,jUI

Fwij(τ)=Fsi(τ+1)si(τ+1)wij(τ)=δi(τ+1;F)xj(τ)

Fwij=τ=t0t1Fwij(τ)wij(τ)wij=τ=t0t1Fwij(τ)=τ=t0t1δi(τ+1;F)xj(τ)

4.BPTT(Back Propagation Through Time)

4.1 Real-Time BPTT

算法描述:
τ(t0,t],kU ,
ϵk(t)=ek(t),
δk(τ)=fk(sk(τ))ϵk(τ),
ϵk(τ1)=lUwlkδl(τ),
可以看出,算法的公式与BP算法非常相似,算法从t时刻开始,先用等式 ϵk(t)=ek(t) 求出 ϵk(t) ,然后再用后边两个等式继续向后迭代,直到 t0 。这里的第一步也被称为错误注入(injecting error),也说是在t时刻注入了 ek(t)
误差传导
上图描述了Real-Time BPTT算法在每一个时刻t的存储和处理操作。历史缓存每经过一个时刻t,就会增加一层的数据(包括该t时刻所有的输入和输出值)。实线箭头表明了当前的输出值由和上一时刻的输入输出值确定。虚线表示反向传播,计算直到 t0+1 δ 。步骤①为injecting error操作,剩下的步骤为每一步的误差计算。

激活函数通常取logistics函数,此时的 fk(sk(τ))=fk(sk(τ))(1fk(sk(τ)))
最后,权值的梯度通过下式计算:
J(t)wij=τ=t0+1tδi(τ)xj(τ1)

在每一个时刻t,算法的执行流程如下:
(1)将当前网络的状态和当前的输入值添加到历史缓存2
(2)注入当前时刻 t

ek(t)
,然后在时间区间 (t0,t] 上进行反向传播,计算出所有的 ϵk(τ),δk(τ)
(3)计算所有的 J(t)wij ;
(4)根据第(3)步的结果修改权值。

随着时间的增长,算法对历史缓存的需求将是无限的,因此,有时也用BPTT(∞)来表示这个算法,它在理论上的研究价值要远大于实用。接下来,我们将讨论更为实用的近似算法。

4.2 Epochwise BPTT

为了解决Real-Time BPTT对内存的无限制需求,我们采用一种近似的算法,即:Epochwise BPTT。
算法的目标是计算基于 Jtotal(t0,t1) 的梯度(即损失函数 F=Jtotal(t0,t1) ),其步骤跟前边类似。同样的,
τ(t0,t1],kU ,
ϵk(t1)=ek(t1),
δk(τ)=fk(sk(τ))ϵk(τ),
ϵk(τ1)=ek(τ1)+lUwlkδl(τ),

算法从最后的时刻 t1 开始,injecting error ek(t1) ,然后运用后边两个等式,迭代计算 δk(τ),ϵk(τ1) ,直到 τ=t0+1 。此时权值的梯度按下式计算:
Jtotal(t0,t1)wij=τ=t0+1t1δi(τ)xj(τ1)
误差传导
[t0,t1] 中所有的输入输出以及目标值都被存储在历史缓存中。实线表示输出由上一时刻的输入和输出确定,当一次epoch完成后,执行BP操作(虚线箭头)。奇数索引的步骤表示error injection,偶数索引的步骤表示误差( δ )传播。一旦BP操作完成,每个权值的梯度就可以算出来了。

算法的执行流程如下:
(1)执行BP算法,计算所有的 ϵk(τ),δk(τ),τ(t0,t1]
(2)计算所有的 Jtotal(t0,t1)wij
(3)使用(2)的结果更新权值,重复步骤(1)~(3);

5.RTRL(Real-Time Recurrent Learning)

与反向传播的BPTT算法不同的是,RTRL通过前向传播梯度来进行计算。

对任意的 kU,iU,jUI,t[t0,t1] ,定义:
pkij(t)=yk(t)wij
F=J(t) ,有:
J(t)wij=kUek(t)pkij(t)

根据之前的关系等式:
sk(t+1)=lUwklyl(t)+lIwklxnetl(t)=lUIwklxl(t)......(2)
yk(t+1)=fk(sk(t+1))......(3)
可以推出:
pkij(t+1)=yk(t+1)wij=yk(t+1)sk(t+1)sk(t+1)wij=fk(sk(t+1))[lUwklplij(t)+δikxj(t)] 3
此外, t0 时刻的输出为预设值,与连接权值无关,所以有:
pkij(t0)=yk(t0)wij=0
于是,整个计算过程将从 t=t0 开始迭代计算,直到 t=t1
对每一个时刻 t ,计算相应的

yk(t)
以及 J(t)wij

6.Hybrid(FP/BPTT)

Fwij=τ=t0t1Fwij(τ)+τ=tt1Fwij(τ)
等式右边的第一部分可写为:
τ=t0t1Fwij(τ)=τ=t0t1lUFyl(t)yl(t)wij(τ)=lUFyl(t)τ=t0t1yl(t)wij(τ)=lUFyl(t)yl(t)wij=lUϵl(t;F)plij(t)
因此,最初的式子可变为:
Fwij=lUϵl(t;F)plij(t)+τ=tt1δi(τ+1;F)xj(τ)
F=Jtotal(t,t)
Jtotal(t,t)wij=lUϵl(t)plij(t)+τ=tt1δi(τ+1)xj(τ)

首先计算BPTT:

ϵk(τ)=δkr  if  τ=tlUwlkδl(τ+1)  if  τ<t

然后,使用上边的计算结果执行:
prij(t)=lUϵl(t)plij(t)+τ=tt1δl(τ+1)xj(τ)
误差传递
上图是FP/BPTT(h)算法的简单描述。可以看到,算法包含两个连续的误差计算过程。一个在时刻 t ,另一个在时刻

t+h
.从时刻 th 直到时刻 t 的输入、输出和目标值都存储在历史缓存中。

7.参考文献

1.Gradient-Based Learning Algorithms for Recurrent Networks and Their Computational Complexity.Ronald J. Williams,David Zipser




  1. F:F{yk(τ)|kU,τ(t,t]}

    F=F(yk(t+1),yk(t+2),...,yk(τ),...,yk(t))
    这地方稍微深入说明一下引入变量 yk(τ) 的原因:
    假设有函数 f(x,y)=x+2y ,同时, y,x 满足: y=x2
    对f(x,y)求偏导数: fx=(x+2y)x
    这个地方出现了两个 x (分别在分式的上下边),这两个x虽然相等,但含义其实并不相同。下边的

    x
    是自变量,上边的 x 其实可以看做自变量的一个函数,不妨令

    t=x
    ,于是有如下关系式:

    {
    x(t)=ty(t)=t2


    于是 f(x,y)=f(x(t),y(t))
    fx=f(x(t),y(t))t
    由复合函数求导法则,上式又可变为:
    f(x(t),y(t))x(t)x(t)t+f(x(t),y(t))y(t)y(t)t
    由于x(t),y(t)是t的单变量函数,有:
    x(t)t=dx(t)dt
    y(t)t=dy(t)dt
    所以有:
    fx=f(x(t),y(t))x(t)dx(t)dt+f(x(t),y(t))y(t)dy(t)dt
    类比函数 F=F(y(t+1),y(t+2),...,yk(τ),...,y(t)) ,对其求关于 yk(τ) 的偏导数显然也存在符号混淆的问题,所以,有必要引入符号
    yk(τ)=yk(τ)(yk(τ))=yk(τ)
    yk(τ)(yk(τ)) 后边的括号表示 yk(τ) yk(τ) 的函数。变量符号 yk(τ) 的意义与上例中 x(t) 的意义一样。

    • 历史缓存(History buffer)中存储了整个网络从 t0 时刻开始的输入和激活信息。
    • δik 是克罗内克函数(Kronecker delta)
      函数定义:
      δik={
      1  if  i=k0  if  ik

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152303.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • java小型图书馆管理系统

    java小型图书馆管理系统根据需求,建立了一个BookMgr类,该类为实现小型图书馆的各个需求。为了和用户有一个良好的交互,根据需求且满足要求中的隐藏条件,先命名了交互的菜单函数printMenu1(),代码如下:publicvoidprintMenu1(){          System.out.println(“欢迎使用图书馆管理系统”);          Syst

    2022年7月8日
    21
  • JS几种数组遍历方式总结

    JS几种数组遍历方式总结JS数组遍历的几种方式JS数组遍历,基本就是for,forin,foreach,forof,map等等一些方法,以下介绍几种本文分析用到的数组遍历方式以及进行性能分析对比第一种:普通for循环代码如下:for(j=0;j&lt;arr.length;j++){}简要说明:最简单的一种,也是使用频率最高的一种,虽然性能不弱,但仍有优化空间第二种:优化版for循环代码如下…

    2022年7月12日
    16
  • oracle修改数据库用户名密码,怎样修改oracle数据库的用户名密码[通俗易懂]

    oracle修改数据库用户名密码,怎样修改oracle数据库的用户名密码[通俗易懂]对于不经常使用数据库的同学们来说,忘记用户名密码是很常见的一件事。下面就让学习啦小编给大家说说怎样修改oracle数据库的用户名密码吧。修改oracle数据库用户名密码的方法进入cmd命令界面(快捷键是win+R)。修改管理员用户密码(一):在命令界面输入sqlplus“/assysdba”即可以管理员身份链接成功。修改管理员用户密码(二):在SQL命令界面输入alterusersyste…

    2022年7月28日
    3
  • css实现导航菜单下拉效果「建议收藏」

    css实现导航菜单下拉效果「建议收藏」通过css也可以实现简单的导航栏效果,一些不会写js的下伙伴不用担心了。先上HTML部分&lt;nav&gt;&lt;ulclass="level"&gt;&lt;li&gt;&lt;ahref=""&gt;首页&lt;/a&gt;&lt;/li&gt;&lt;li&gt;

    2022年7月26日
    11
  • 关于FindWindow函数「建议收藏」

    关于FindWindow函数「建议收藏」在调用FindWindow函数的时候,应该第一个参数为空,第二个参数为窗口的标题名。classname是窗口在创建时的注册名称,不是源代码的类名,通常可以不指定,除非确切地知道。另外,vs自带一个spy++的工具,可以探查当前所有窗口的信息,包括注册类名。FindWindow这个函数检索处理顶级窗口的类名和窗口名称匹配指定的字符串。这个函数有两个参数,第一个是要找的窗口的类,第二个是要找的窗口的…

    2022年8月13日
    3
  • Binwalk工具的详细使用说明

    https://blog.csdn.net/QQ1084283172/article/details/66971242一、binwalk工具的基本用法介绍1.获取帮助信息$binwalk-h#或者$binwalk–help2.固件分析扫描$binwalkfirmware.bin#或者$binwalkfirmware.bin|hea…

    2022年4月4日
    34

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号