使用Keras实现 基于注意力机制(Attention)的 LSTM 时间序列预测

时隔半年多,毕设男孩终于重操旧业,回到了LSTM进行时间序列预测和异常检测的路上。如果有阅读过我之前的博客,可以发现使用LSTM作单类的时间序列异常检测也是基于对于时间序列的预测进行登堂入室LSTM:使用LSTM进行简单的时间序列异常检测本次我们要进行的是使用注意力机制+LSTM进行时间序列预测,项目地址为KerasAttentionMechanism首先我们把它git…

大家好,又见面了,我是你们的朋友全栈君。

时隔半年多,毕设男孩终于重操旧业,回到了 LSTM进行时间序列预测和异常检测的路上。

如果有阅读过我之前的博客,可以发现使用 LSTM作单类的时间序列异常检测也是基于对于时间序列的预测进行 登堂入室LSTM:使用LSTM进行简单的时间序列异常检测

本次我们要进行的是 使用 注意力机制 + LSTM 进行时间序列预测,项目地址为Keras Attention Mechanism

对于时间步的注意力机制

首先我们把它git clone 到本地,然后配置好所需环境 笔者的 tensorflow版本为1.6.0 Keras 版本为 2.0.2
打开文件夹,我们主要需要的是attention_lstm.py 以及 attention_utils.py 脚本

项目中生成数据的函数为

def get_data_recurrent(n, time_steps, input_dim, attention_column=10):
    """ Data generation. x is purely random except that it's first value equals the target y. In practice, the network should learn that the target = x[attention_column]. Therefore, most of its attention should be focused on the value addressed by attention_column. :param n: the number of samples to retrieve. :param time_steps: the number of time steps of your series. :param input_dim: the number of dimensions of each element in the series. :param attention_column: the column linked to the target. Everything else is purely random. :return: x: model inputs, y: model targets """
    x = np.random.standard_normal(size=(n, time_steps, input_dim))
    y = np.random.randint(low=0, high=2, size=(n, 1))
    x[:, attention_column, :] = np.tile(y[:], (1, input_dim))
    return x, y

默认的 n = 30000, input_dim = 2 ,timesteps = 20。生成的数据为:

shape
x 30000 x 20 x 2
y 30000 x 1

其中 x 的第11个 timestep 两维的数据 与y相同,其他timestep 维的数据为随机数。

所以当我们使用这样的数据去进行 注意力机制 LSTM 的训练,我们希望得到的结果是 注意力层 主要关注第11个timestep 而对其他timestep 的关注度较低。

直接运行 attention_lstm.py 脚本
此时的网络结构为:
在这里插入图片描述
可以看到是在 LSTM 层之后使用了注意力机制

最后会汇总画一张图
在这里插入图片描述
可以看到 可以看到注意力的权重主要汇总在了第11个timestep,说明注意力机制很成功

对于维的注意力机制

上述的例子 是将注意力机制使用在了 timestep 上,决定哪个时间步对于结果的影响较大。
而如果我们想将 注意力机制使用在维上呢? 比如使用多维去预测一维的数据,我们想使用注意力机制
决定哪些维对于预测维起关键作用。

比较简单的方法就是将输入数据 reshape 一下 将timesteps 与 input_dim 维对换 再运行就可以了,因为本代码的设置就是对 输入的第2维加入注意力机制.

进阶的方法就是 自写一下 attention_3d_block 函数:

def attention_3d_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = inputs
    #a = Permute((2, 1))(inputs)
    #a = Reshape((input_dim, TIME_STEPS))(a) # this line is not useful. It's just to know which dimension is what.
    a = Dense(input_dim, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((1, 2), name='attention_vec')(a)
    #a_probs = a
    output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
    return output_attention_mul

其实严格来讲我们所做的改变不多,作者使用了 Permute层对于数据进行了 第2和第3维的对换,我们则没有进行对换操作。

接下来 再在attention_utils.py 脚本中写一个产生数据集的新函数:

def get_data_recurrent2(n, time_steps, input_dim, attention_dim=5):
    """ 假设 input_dim = 10 time_steps = 6 产生一个 x 6 x 10 的数据 其中每步的第 6 维 与 y相同 """
    x = np.random.standard_normal(size=(n, time_steps, input_dim))
    y = np.random.randint(low=0, high=2, size=(n, 1))
    x[:,:,attention_dim] =  np.tile(y[:], (1, time_steps))

    return x,y

试着产生一组数据 get_data_recurrent2(1,6,10)
在这里插入图片描述

然后我们稍微改动一下main函数进行新的训练。迭代十次后结果为:
在这里插入图片描述
可以看到,第6维的权重比较大。
如果我们对于timesteps的注意力画一个汇总图,即改动一下

  attention_vector = np.mean(get_activations(m, testing_inputs_1,print_shape_only=False,layer_name='attention_vec')[0], axis=2).squeeze()

可以看到对于timesteps的注意力是相同的(其实如果对于开头时间步的注意力机制,对输入维的注意力画一个汇总图,也是相同的)
在这里插入图片描述

对于时间步和输入维的注意力机制

待补充

注:参考 keras-attention-mechanism
以及 Keras中文文档

代码已上传到我的github

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/125920.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • qq怎样防盗 qq密码如何防盗[通俗易懂]

    qq怎样防盗 qq密码如何防盗[通俗易懂]1.去腾讯申请密码保护,这样如果密码被激活成功教程或自己忘记了还可以利用密码保护功能取回来。2.QQ密码的位数一定要超过8位,而且最好包含数字、字母和特殊符号,否则以现代计算机的超强计算能力,要想暴力激活成功教程你的QQ密码简直是易如反掌。3.不要在QQ中填入真实的年龄、E-mail等敏感消息,更不能告诉任何人,小心行得万年船。4.不要随意运行别人发给你的文件,即便那些看起来很诱人的…

    2022年7月20日
    15
  • SSM项目部署到阿里云服务器。只需要五个步骤。

    SSM项目部署到阿里云服务器。只需要五个步骤。最近在看很多ssm项目部署到阿里云的教程:踩了很多坑,所以希望大家部署时候有所借鉴吧。有什么不懂可以联系qq交流:980631161.主要分为五个步骤:1.购买服务器2下载xshell和Xftp63.在服务器上安装jdk,mysql,tomcat。4.数据库准备数据5.maven项目生成war文件。1.购买服务器在阿里云购买一个ESC服务器网址是:https://www….

    2022年6月20日
    58
  • c语言opencv读取图像_matlab读取一幅图像并显示

    c语言opencv读取图像_matlab读取一幅图像并显示函数cv2.imread()用于从指定的文件读取图像OpenCV完整例程200篇01.图像的读取(cv2.imread)02.图像的保存(cv2.imwrite)03.图像的显示(cv2.imshow)07.图像的创建(np.zeros)08.图像的复制(np.copy)09.图像的裁剪(cv2.selectROI)10.图像的拼接(np.hstack)……………

    2022年8月31日
    3
  • p2p在线直播流(何为流媒体)

    看到网上一些吹牛P2P低延时的文章,觉得不是很靠谱,抽空调研了一下这个问题。P2P低延时的几个方向:   方法一:通过直接采集并编码多媒体帧,将多媒体帧切分成1KB大小的数据颗粒,采用push策略的进行小包传输,提高传输效率,减小传输延时;          具体参见:http://www.google.com/patents/CN101945129A?cl

    2022年4月10日
    71
  • pgrouting 路径规划_路径分析是什么意思

    pgrouting 路径规划_路径分析是什么意思一.技术背景,相关技术介绍   PgRouting是基于开源空间数据库PostGIS用于网络分析的扩展模块,最初它被称作pgDijkstra,因为它只是利用Dijkstra算法实现最短路径搜索,之后慢慢添加了其他的路径分析算法,如A算法,双向A算法,Dijkstra算法,双向Dijkstra算法,tsp货郎担算法等,然后被更名为pgRouting[1]。该扩展库依托PostGIS自身的g…

    2022年8月24日
    5
  • BigDecimal 加减乘除[通俗易懂]

    在java里面,int的最大值是:2147483647,现在如果想用比这个数大怎么办?换句话说,就是数值较大,这时候就用到了BigDecimal 下载整理了一下BigDecimal的加减乘除。。 BigDecimalbignum1=newBigDecimal(“10”); BigDecimalbignum2=newBigDecimal(

    2022年4月14日
    77

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号