seq2seq模型是什么_seq2seq原理

seq2seq模型是什么_seq2seq原理1seq2seq模型简介seq2seq模型是一种基于【encoder-decoder】(编码器-解码器)框架的神经网络模型,广泛应用于自然语言翻译、人机对话等领域。目前,【seq2seq+attention】(注意力机制)已被学者拓展到各个领域。seq2seq于2014年被提出,注意力机制于2015年被提出,两者于2017年进入火热融合和拓展阶段。通常,编码器和解码器都是一个LSTM网络…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全家桶1年46,售后保障稳定

1 seq2seq模型简介

seq2seq 模型是一种基于【 Encoder-Decoder】(编码器-解码器)框架的神经网络模型,广泛应用于自然语言翻译、人机对话等领域。目前,【seq2seq+attention】(注意力机制)已被学者拓展到各个领域。seq2seq于2014年被提出,注意力机制于2015年被提出,两者于2017年进入疯狂融合和拓展阶段。

1.1 seq2seq原理

通常,编码器和解码器可以是一层或多层 RNN、LSTM、GRU 等神经网络。为方便讲述原理,本文以 RNN 为例。seq2seq模型的输入和输出长度可以不一样。如图,Encoder 通过编码输入序列获得语义编码 C,Decoder 通过解码 C 获得输出序列。

seq2seq模型是什么_seq2seq原理
seq2seq网络结构图

 Encoder

seq2seq模型是什么_seq2seq原理

Decoder

seq2seq模型是什么_seq2seq原理

说明:xi、hi、C、h’i 都是列向量 

1.2 seq2seq+attention原理

普通的 seq2seq 模型中,Decoder 每步的输入都是相同的语义编码 C,没有针对性的学习,导致解码效果不佳。添加注意力机制后,使得每步输入的语义编码不一样,捕获的信息更有针对性,解码效果更佳。

seq2seq模型是什么_seq2seq原理
seq2seq+attention网络结构图

Encoder

seq2seq模型是什么_seq2seq原理

Decoder

\large h=\{h_1,h_2,...,h_n\}

seq2seq模型是什么_seq2seq原理

(1)标准 attention

seq2seq模型是什么_seq2seq原理

其中 ,v、W、U 都是待学习参数,v 为列向量,W、U 为矩阵

(2)attention 扩展

扩展的 attention 机制有3种方法,如下。其中,v、W 都是待学习参数,v 为列向量,W为矩阵。相较于标准的 attention,待学习的参数明显减少了些。

seq2seq模型是什么_seq2seq原理

说明:xi、hi、Ci、h’i、wi 、ei 都是列向量,h 是矩阵 

2 安装seq2seq

若下载比较慢,可以先通过【码云】导入,再在码云上下载,如下:

seq2seq模型是什么_seq2seq原理

本文以MNIST手写数字分类为例,讲解 seq2seq 模型和 AtttionSeq2seq 模型的实现。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下: 

seq2seq模型是什么_seq2seq原理

代码资源见–>seq2seq模型和基于注意力机制的seq2seq模型 

3 SimpleSeq2Seq

SimpleSeq2Seq(input_length, input_dim, hidden_dim, output_length, output_dim, depth=1)

Jetbrains全家桶1年46,售后保障稳定

  •  input_length:输入序列长度
  • input_dim:输入序列维度
  • output_length:输出序列长度
  • output_dim:输出序列维度
  • depth:Encoder 和 Decoder 的深度,取值可以为整数或元组。如 depth=3,表示 Encoder 和 Decoder 都有 3 层;depth=(3, 4) 表示 Encoder 有3层和 Decoder 有4层

SimpleSeq2Seq.py

from tensorflow.examples.tutorials.mnist import input_data
from seq2seq.models import SimpleSeq2Seq
from keras.models import Sequential
from keras.layers import Dense,Flatten

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
    valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
    test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#SimpleSeq2Seq模型
def seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y):
    #创建模型
    model=Sequential()
    seq=SimpleSeq2Seq(input_dim=28,hidden_dim=32,output_length=10,output_dim=10)
    model.add(seq)
    model.add(Flatten())  #扁平化
    model.add(Dense(10,activation='softmax'))
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=25,verbose=2,validation_data=(valid_x,valid_y))    
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
   
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
model_14 (Model)             (None, 10, 10)            10368     
_________________________________________________________________
flatten_1 (Flatten)          (None, 100)               0         
_________________________________________________________________
dense_23 (Dense)             (None, 10)                1010      
=================================================================
Total params: 11,378
Trainable params: 11,378
Non-trainable params: 0

网络训练结果:

Epoch 23/25
 - 17s - loss: 0.1521 - acc: 0.9563 - val_loss: 0.1400 - val_acc: 0.9598
Epoch 24/25
 - 17s - loss: 0.1545 - acc: 0.9553 - val_loss: 0.1541 - val_acc: 0.9536
Epoch 25/25
 - 17s - loss: 0.1414 - acc: 0.9594 - val_loss: 0.1357 - val_acc: 0.9624
test_loss: 0.14208583533763885 - test_acc: 0.9567999958992004

4 AttentionSeq2Seq

AttentionSeq2Seq(input_length, input_dim, hidden_dim, output_length, output_dim, depth=1)
  •  input_length:输入序列长度
  • input_dim:输入序列维度
  • output_length:输出序列长度
  • output_dim:输出序列维度
  • depth:Encoder 和 Decoder 的深度,取值可以为整数或元组。如 depth=3,表示 Encoder 和 Decoder 都有 3 层;depth=(3, 4) 表示 Encoder 有3层和 Decoder 有4层

AttentionSeq2Seq.py

from tensorflow.examples.tutorials.mnist import input_data
from seq2seq.models import AttentionSeq2Seq
from keras.models import Sequential
from keras.layers import Dense,Flatten

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
    valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
    test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#AttentionSeq2Seq模型
def seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y):
    #创建模型
    model=Sequential()
    seq=AttentionSeq2Seq(input_length=28,input_dim=28,hidden_dim=32,output_length=10,output_dim=10)
    model.add(seq)
    model.add(Flatten())  #扁平化
    model.add(Dense(10,activation='softmax'))
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=25,verbose=2,validation_data=(valid_x,valid_y))    
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
   
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
seq2Seq(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
model_102 (Model)            (None, 10, 10)            24459     
_________________________________________________________________
flatten_6 (Flatten)          (None, 100)               0         
_________________________________________________________________
dense_176 (Dense)            (None, 10)                1010      
=================================================================
Total params: 25,469
Trainable params: 25,469
Non-trainable params: 0

网络训练结果:

Epoch 23/25
 - 36s - loss: 0.0533 - acc: 0.9835 - val_loss: 0.0719 - val_acc: 0.9794
Epoch 24/25
 - 37s - loss: 0.0511 - acc: 0.9843 - val_loss: 0.0689 - val_acc: 0.9800
Epoch 25/25
 - 37s - loss: 0.0473 - acc: 0.9860 - val_loss: 0.0700 - val_acc: 0.9802
test_loss: 0.06055343023035675 - test_acc: 0.9825000047683716

SimpleSeq2Seq 模型和 AttentionSeq2Seq 模型的预测精度分别为 0.9568、0.9825,说明添加注意力机制后,预测精度有了明显的提示。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/234854.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • “备份集中的数据库备份与现有的数据库不同”解决方法「建议收藏」

    “备份集中的数据库备份与现有的数据库不同”解决方法「建议收藏」原文发布时间为:2010-09-16——来源于本人的百度文章[由搬家工具导入]最主要就是要在“选项”中选择“覆盖现有数据库”,否则就会出现“备份集中的数据库备份与现有的数据库”的问题。以前一直使用SQLServer2000,现在跟潮流都这么紧,而且制定要求使用SQLServer2005,就在现在的项目中使用它了。对于SQLServer2005,有几个地方是要注意的,比方在还原数据库…

    2022年4月30日
    48
  • System.err.println()和System.out.println()区别

    System.err.println()和System.out.println()区别看了些资料总结下:1.JDK文档对两者的解释:out:“标准”输出流。此流已打开并准备接受输出数据。通常,此流对应于显示器输出或者由主机环境或用户指定的另一个输出目标。err:“标准”错误输出流。此流已打开并准备接受输出数据。通常,此流对应于显示器输出或者由主机环境或用户指定的另一个输出目标。按照惯例,此输出流用于显示错误消息,或者显示那些即使用户输出流(变量 out 的值)已经重定向…

    2022年6月13日
    28
  • pandas无法打开.xlsx文件,xlrd.biffh.XLRDError: Excel xlsx file; not supported

    pandas无法打开.xlsx文件,xlrd.biffh.XLRDError: Excel xlsx file; not supported原因是最近xlrd更新到了2.0.1版本,只支持.xls文件。所以pandas.read_excel(‘xxx.xlsx’)会报错。可以安装旧版xlrd,在cmd中运行:pipuninstallxlrdpipinstallxlrd==1.2.0

    2022年10月20日
    0
  • sql数据库回滚操作_sql回滚语句 rollback

    sql数据库回滚操作_sql回滚语句 rollbackcreatetable testtable(idnvchart(50)primkey,namenvchart(50),remarknvchart(50))select*fromtesttable go BEGINTRY –SQLServer需要显示的定义开始一个事务.BEGINTRANSACTION;–插入2条同样的数据

    2022年8月30日
    1
  • 成本=固定成本+可变成本_可避免固定成本是机会成本吗

    成本=固定成本+可变成本_可避免固定成本是机会成本吗1、固定成本和可变成本根据成本费用与产量的关系可将总成本费用分为:可变成本;是指随着产品产量的增减而成正比例变化的各项费用。固定成本:是指不随产品产量的变化的各项成本费用。半可变(或半固定)成本:有些成本费用属于半可变成本,如不能熄灭的工业炉的燃料费等。工资、营业费用和流动资金利息等也都可能既有可变因素,又有固定因素。必要时需将半可变(或半固定)成进一步分解为可变成本和…

    2022年10月22日
    0
  • PHP递归算法_php递归函数详解

    PHP递归算法_php递归函数详解递归算法的实现方法是有多种的,如通过“静态变量”、“全局变量”、“引用传参”的方式:静态变量的方法:<?phpfunctioncall(){static$i=0;echo$i.”;$i++;if($i<10){call();}}call();输出:012345678…

    2022年8月11日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号