双向 LSTM

双向 LSTM本文结构:为什么用双向LSTM什么是双向LSTM例子为什么用双向LSTM?单向的RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的,例如,我今天不舒服,我打算__一天。只根据‘不舒服‘,可能推出我打算‘去医院‘,‘睡觉‘,‘请假‘等等,但如果加上后面的‘一天‘,能选择的范围就变小了,‘去医院‘这种就不能选了,而‘请假‘‘休息‘之类的被选择概率就会更大。什么是双向L

大家好,又见面了,我是你们的朋友全栈君。

本文结构:

  • 为什么用双向 LSTM
  • 什么是双向 LSTM
  • 例子

为什么用双向 LSTM?

单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的,
例如,

我今天不舒服,我打算__一天。

只根据‘不舒服‘,可能推出我打算‘去医院‘,‘睡觉‘,‘请假‘等等,但如果加上后面的‘一天‘,能选择的范围就变小了,‘去医院‘这种就不能选了,而‘请假‘‘休息‘之类的被选择概率就会更大。


什么是双向 LSTM?

双向卷积神经网络的隐藏层要保存两个值, A 参与正向计算, A’ 参与反向计算。
最终的输出值 y 取决于 A 和 A’:

双向 LSTM

即正向计算时,隐藏层的 s_t 与 s_t-1 有关;反向计算时,隐藏层的 s_t 与 s_t+1 有关:

双向 LSTM

双向 LSTM

在某些任务中,双向的 lstm 要比单向的 lstm 的表现要好:

双向 LSTM


例子

下面是一个 keras 实现的 双向LSTM 应用的小例子,任务是对序列进行分类,
例如如下 10 个随机数:

0.63144003 0.29414551 0.91587952 0.95189228 0.32195638 0.60742236 0.83895793 0.18023048 0.84762691 0.29165514

累加值超过设定好的阈值时可标记为 1,否则为 0,例如阈值为 2.5,则上述输入的结果为:

0 0 0 1 1 1 1 1 1 1

和单向 LSTM 的区别是用到 Bidirectional:
model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(n_timesteps, 1)))

from random import random
from numpy import array
from numpy import cumsum
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import TimeDistributed
from keras.layers import Bidirectional

# create a sequence classification instance
def get_sequence(n_timesteps):
    # create a sequence of random numbers in [0,1]
    X = array([random() for _ in range(n_timesteps)])
    # calculate cut-off value to change class values
    limit = n_timesteps/4.0
    # determine the class outcome for each item in cumulative sequence
    y = array([0 if x < limit else 1 for x in cumsum(X)])
    # reshape input and output data to be suitable for LSTMs
    X = X.reshape(1, n_timesteps, 1)
    y = y.reshape(1, n_timesteps, 1)
    return X, y

# define problem properties
n_timesteps = 10

# define LSTM
model = Sequential()
model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(n_timesteps, 1)))
model.add(TimeDistributed(Dense(1, activation='sigmoid')))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])

# train LSTM
for epoch in range(1000):
    # generate new random sequence
    X,y = get_sequence(n_timesteps)
    # fit model for one epoch on this sequence
    model.fit(X, y, epochs=1, batch_size=1, verbose=2)

# evaluate LSTM
X,y = get_sequence(n_timesteps)
yhat = model.predict_classes(X, verbose=0)
for i in range(n_timesteps):
    print('Expected:', y[0, i], 'Predicted', yhat[0, i])

学习资料:
https://zybuluo.com/hanbingtao/note/541458
https://maxwell.ict.griffith.edu.au/spl/publications/papers/ieeesp97_schuster.pdf
http://machinelearningmastery.com/develop-bidirectional-lstm-sequence-classification-python-keras/


推荐阅读
历史技术博文链接汇总
也许可以找到你想要的:
[入门问题][TensorFlow][深度学习][强化学习][神经网络][机器学习][自然语言处理][聊天机器人]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/147926.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 具体说明 Flume介绍、安装和配置

    具体说明 Flume介绍、安装和配置

    2022年1月6日
    48
  • python标识符在命名时有哪些规则_php标识符的命名规则

    python标识符在命名时有哪些规则_php标识符的命名规则在Python中,一切都是对象,包括常量数据类型,如整数数据类型(1,2,3…),字符串数据类型(“ABC”)。想要使用这些对象,就要使用它的对象引用。赋值操作符,实际上是使得对象引用对内存中存放数据的对象进行引用。那什么是标识符?标识符是对对象引用起的一个名字。有效的Python标识符规则:1.长度任意长;2.标识符不能与关键字同名;3.在2.x版本的Python中,标识符以ASCII的字母…

    2025年9月23日
    9
  • Springcloud分布式事务_分布式事务解决方案框架

    Springcloud分布式事务_分布式事务解决方案框架publicAjaxMessageEntityuserWithdraw(@RequestBodyTPayInfotPayInfo,HttpServletRequestrequest){if(参数校验){//参数校验没有通过,直接返回参数校验错误}if(通过redis做并发控制,使同时提现人数不能…

    2022年4月19日
    58
  • 微信公众平台 获取用户openid

    微信公众平台 获取用户openid今天做微信公众号获取用户的openid,圆满成功,特此来一发。 第一步:理解逻辑。 1:获取openid的逻辑获得微信的openid,需要先访问微信提供的一个网址:这个网址名为url1,下面有赋值。通过这个网址,微信用来识别appid信息,在这个网址中,有一个属性redirect_uri,是微识别完appid后,进行跳转的操作,可以是网页,也可以是servlet,我这里用的是…

    2022年6月26日
    88
  • nslookup两种错误解决方法

    nslookup两种错误解决方法

    2021年8月14日
    346
  • sqlserver 视图创建索引_Oracle创建索引

    sqlserver 视图创建索引_Oracle创建索引一、索引1、添加索引createindex索引对象名on索引对应表名(表内索引对象字段名);例:需创建包含userid属性的userinfo表。createindexuseridonsystem.userinfo(userid);2、删除索引dropindex索引对象名;例:dropindexuserid;二、视图(并不是真实存在的一张表)1、创建视图createview视图名(学号,姓名,科目,成绩)asselect对应在表格中的字段名from涉

    2025年9月27日
    6

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号