基于keras的双层LSTM网络和双向LSTM网络

基于keras的双层LSTM网络和双向LSTM网络1前言基于keras的双层LSTM网络和双向LSTM网络中,都会用到LSTM层,主要参数如下:LSTM(units,input_shape,return_sequences=False)units:隐藏层神经元个数 input_shape=(time_step,input_feature):time_step是序列递归的步数,input_feature是输入特征维数 re…

大家好,又见面了,我是你们的朋友全栈君。

1 前言

基于keras的双层LSTM网络双向LSTM网络中,都会用到 LSTM层,主要参数如下:

LSTM(units,input_shape,return_sequences=False)
  • units:隐藏层神经元个数
  • input_shape=(time_step, input_feature):time_step是序列递归的步数,input_feature是输入特征维数
  • return_sequences: 取值为True,表示每个时间步的值都返回;取值为False,表示只返回最后一个时间步的取值

本文以MNIST手写数字分类为例,讲解双层LSTM网络和双向LSTM网络的实现。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

基于keras的双层LSTM网络和双向LSTM网络

代码资源见–> 双隐层LSTM和双向LSTM

2 双层LSTM网络

基于keras的双层LSTM网络和双向LSTM网络
双层LSTM网络结构

 DoubleLSTM.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.layers import Dense,LSTM

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
    valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
    test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#双层LSTM模型
def DoubleLSTM(train_x,train_y,valid_x,valid_y,test_x,test_y):
    #创建模型
    model=Sequential()
    model.add(LSTM(64,input_shape=(28,28),return_sequences=True))  #返回所有节点的输出
    model.add(LSTM(32,return_sequences=False))  #返回最后一个节点的输出
    model.add(Dense(10,activation='softmax'))
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=15,verbose=2,validation_data=(valid_x,valid_y))
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
   
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DoubleLSTM(train_x,train_y,valid_x,valid_y,test_x,test_y)

每层网络输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_5 (LSTM)                (None, 28, 64)            23808     
_________________________________________________________________
lstm_6 (LSTM)                (None, 32)                12416     
_________________________________________________________________
dense_5 (Dense)              (None, 10)                330       
=================================================================
Total params: 36,554
Trainable params: 36,554
Non-trainable params: 0

由于第一个LSTM层设置了 return_sequences=True,每个节点的输出值都会返回,因此输出尺寸为 (None, 28, 64) 

由于第二个LSTM层设置了 return_sequences=False,只有最后一个节点的输出值会返回,因此输出尺寸为 (None, 32) 

 训练结果:

Epoch 13/15
 - 17s - loss: 0.0684 - acc: 0.9796 - val_loss: 0.0723 - val_acc: 0.9792
Epoch 14/15
 - 18s - loss: 0.0633 - acc: 0.9811 - val_loss: 0.0659 - val_acc: 0.9822
Epoch 15/15
 - 17s - loss: 0.0597 - acc: 0.9821 - val_loss: 0.0670 - val_acc: 0.9812
test_loss: 0.0714278114028275 - test_acc: 0.9789000034332276

3 双向LSTM网络

基于keras的双层LSTM网络和双向LSTM网络
双向LSTM网络结构
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.layers import Dense,LSTM,Bidirectional

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
    valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
    test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#双向LSTM模型
def BiLSTM(train_x,train_y,valid_x,valid_y,test_x,test_y):
    #创建模型
    model=Sequential()
    lstm=LSTM(64,input_shape=(28,28),return_sequences=False)  #返回最后一个节点的输出
    model.add(Bidirectional(lstm))  #双向LSTM
    model.add(Dense(10,activation='softmax'))
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=15,verbose=2,validation_data=(valid_x,valid_y))
    #查看网络结构
    model.summary()
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
   
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
BiLSTM(train_x,train_y,valid_x,valid_y,test_x,test_y)

 每层网络输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_5 (Bidirection (None, 128)               47616     
_________________________________________________________________
dense_6 (Dense)              (None, 10)                1290      
=================================================================
Total params: 48,906
Trainable params: 48,906
Non-trainable params: 0

由于LSTM层设置了 return_sequences=False,只有最后一个节点的输出值会返回,每层LSTM返回64维向量,两层合并共128维,因此输出尺寸为 (None, 128) 

训练结果: 

Epoch 13/15
 - 22s - loss: 0.0512 - acc: 0.9839 - val_loss: 0.0632 - val_acc: 0.9790
Epoch 14/15
 - 22s - loss: 0.0453 - acc: 0.9865 - val_loss: 0.0534 - val_acc: 0.9832
Epoch 15/15
 - 22s - loss: 0.0418 - acc: 0.9869 - val_loss: 0.0527 - val_acc: 0.9830
test_loss: 0.06457789749838412 - test_acc: 0.9795000076293945

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/149940.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 矩阵范数与向量范数关系_矩阵范数的定义

    矩阵范数与向量范数关系_矩阵范数的定义范数是距离在向量和矩阵上的推广,在研究收敛性、判断矩阵非奇异等方面有广泛应用。本节包括以下内容:(1)向量范数;(2)矩阵范数;(3)从属范数;(4)谱半径;(5)矩阵的非奇异条件。1向量范数从向量到实数的映射/函数。定义(1)条件:非负性、齐次性、三角不等式(∥x+y∥≤∥x∥+∥y∥\|x+y\|\leq\|x\|+\|y\|)。

    2026年1月24日
    6
  • UITableView是不会响应touchesBegan:方法的

    UITableView是不会响应touchesBegan:方法的2019独角兽企业重金招聘Python工程师标准>>>…

    2022年7月25日
    17
  • pycharm的git_pycharm版本控制

    pycharm的git_pycharm版本控制1、createpatchcreatepatch打补丁,当连接不上git服务器时,可以先本地打补丁,生成一个文件,里面记录了文件变更信息,后面可以随时提交至git服务器2、checkoutrevisioncheckoutrevision检出版本,可以回退到任意版本,右边会显示当前检出版本与上一版本的变化3、newbranchnewbranch建立新的分…

    2022年8月29日
    5
  • 安装python时出现的错误0x80072efd及0x80072f7d的解决方法

    安装python时出现的错误0x80072efd及0x80072f7d的解决方法0x80072efd:是下载不了dubuggingsymbols和debugbinaries的问题。要翻墙。或者把2个Download的安装选项取消,就可以完成了。0x80072f7d:修改了EXE文件名称安装成功感觉问题解决的十分不靠谱,感谢https://blog.csdn.net/quantum7/article/details/81738839,脑残丞相的提醒,他安装时也…

    2025年7月31日
    3
  • Python多分类问题pr曲线绘制(含代码)

    Python多分类问题pr曲线绘制(含代码)研究了三天的多分类 pr 曲线问题终于在昨天晚上凌晨一点绘制成功了 现将所学所感记录一下 一来怕自己会忘可以温故一下 二来希望能给同样有疑惑的铁子们一些启迪 下图为我画的 pr 曲线 因为准确度超过了 97 所以曲线很饱和 首先了解一下二分类中的 pr 曲线是怎么画的 p 是 precition 是查准率 也是我们常用到的准确率 r 是 recall 是查全率 也叫召回率 上图为测试结果的混淆矩阵 表示一个数据集上的所有测试结果 其中竖列均为测试结果 即分类器预测概率大于 0 5 为正类 小于 0

    2025年11月24日
    4
  • Java和Java大数据有什么区别?

    Java和Java大数据有什么区别?单单提起java或者大数据,很多人对此都一目了然,但对于Java大数据这样一个新鲜名词,多少有些疑惑。那java和java大数据学习的内容是一样的吗?两者有什么区别呢?今天就从java和java大数据的以下方面谈谈两者的区别。Java和Java大数据有什么区别Java和大数据的关系:java是计算机的一门编程语言;可以用来做很多工作,大数据开发属于其中一种;大数据…

    2022年5月25日
    46

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号