Keras中创建LSTM模型的步骤[通俗易懂]

Keras中创建LSTM模型的步骤[通俗易懂]目录写在前面概述环境1、定义网络2、编译网络3、训练网络4、评估网络5、进行预测一个LSTM示例总结写在前面本文是对The5StepLife-CycleforLongShort-TermMemoryModelsinKeras的翻译,新手博主,边学边记,以便后续温习,或者对他人有所帮助概述深度学习神经网络在Python中很容易使用Keras创建和评估,但您必须遵循严格的模型生命周期。在这篇文章中,您将了解创建、训练和评估Keras中长期记忆(LSTM)循环神经网络的分步生

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

写在前面

本文是对The 5 Step Life-Cycle for Long Short-Term Memory Models in Keras的复现与解读,新手博主,边学边记,以便后续温习,或者对他人有所帮助

概述

深度学习神经网络在 Python 中很容易使用 Keras 创建和评估,但您必须遵循严格的模型生命周期。
在这篇文章中,您将了解创建、训练和评估Keras中长期记忆(LSTM)循环神经网络的分步生命周期,以及如何使用训练有素的模型进行预测。
阅读这篇文章后,您将知道:

  1. 如何定义、编译、拟合和评估 Keras 中的 LSTM;
  2. 如何为回归和分类序列预测问题选择标准默认值。;
  3. 如何将所有连接在一起,在 Keras 开发和运行您的第一个 LSTM 循环神经网络。
    可以参考Long Short-Term Memory Networks With Python,包含了所有示例的教程以及Python源代码文件

环境

本教程假定您安装了 Python SciPy 环境。此示例可以使用 Python 2 或 3。

本教程假定您已使用 TensorFlow 或 Theano 后端安装了 Keras v2.0 或更高版本。

本教程还假定您安装了scikit-learn、pandas、NumPy 和 Matplotlib。

接下来,让我们来看看一个标准时间序列预测问题,我们可以用作此实验的上下文。

1、定义网络

第一步是定义您的网络。

神经网络在 Keras 中定义为一系列图层。这些图层的容器是顺序类。

第一步是创建顺序类的实例。然后,您可以创建图层,并按应连接它们的顺序添加它们。由内存单元组成的LSTM循环层称为LSTM()。通常跟随 LSTM 图层并用于输出预测的完全连接层称为 Dense()。

例如,我们可以通过两个步骤完成操作:

model = Sequential()
model.add(LSTM(2))
model.add(Dense(1))

但是,我们也可以通过创建层数组并传递到序列的构造函数来一步完成。

layers = [LSTM(2), Dense(1)]
model = Sequential(layers)

网络中的第一层必须定义预期输入数。输入必须是三维的,由Samples、Timesteps和Features组成。
Samples:数据中的行
Timesteps:特征的过去观测值
features:数据中的列

假设数据作为 NumPy 数组加载,您可以使用 NumPy 中的 reshape()函数将 2D 数据集转换为 3D 数据集。如果希望列成为一个特征的时间步长,可以使用:

data = data.reshape((data.shape[0], data.shape[1], 1))

如果希望 2D 数据中的列通过一个时间步成为特征,可以使用:

data = data.reshape((data.shape[0], 1, data.shape[1]))

您可以指定input_shape,该参数需要包含时间步长数和特征数的元组。例如,如果我们有两个时间步长和一个特征的单变量时间序列与两个滞后观测值每行,它将指定如下:

model = Sequential()
model.add(LSTM(5, input_shape=(2,1)))
model.add(Dense(1))

LSTM 图层可以通过将它们添加到顺序模型来堆叠。重要的是,在堆叠 LSTM 图层时,我们必须为每个输入输出一个序列而不是单个值,以便后续 LSTM 图层可以具有所需的 3D 输入。我们可以通过将”return_sequences true 来做到这一点。例如:

model = Sequential()
model.add(LSTM(5, input_shape=(2,1), return_sequences=True))
model.add(LSTM(5))
model.add(Dense(1))

将顺序模型视为一个管道,最终输入原始数据,并在另一个数据中显示预测。

这是 Keras 中的有用容器,因为传统上与图层关联的关注点也可以拆分并添加为单独的图层,清楚地显示它们在数据从输入到预测转换中的作用。

例如,可以将从图层中每个神经元转换求和信号的激活函数提取并添加到序列中,作为称为”激活”的图层样对象。

model = Sequential()
model.add(LSTM(5, input_shape=(2,1)))
model.add(Dense(1))
model.add(Activation('sigmoid'))

激活函数的选择对于输出层来说至关重要,因为它将定义预测将采用的格式。

例如,下面是一些常见的预测建模问题类型以及可以在输出层中使用的结构和标准激活函数:

回归:线性激活函数,或”linear”,以及与输出数匹配的神经元数。
二元分类:逻辑激活功能,或”sigmoid”,一个神经元输出层。
多类分类: Softmax激活函数,或”softmax”,每个类值一个输出神经元,假设为一热编码的输出模式。

2、编译网络

一旦我们定义了我们的网络,我们必须编译它。

编译是效率的一步。它将我们定义的简单层序列转换为一系列高效的矩阵转换,其格式旨在根据 Keras 的配置方式在 GPU 或 CPU 上执行。

将编译视为网络的预计算步骤。定义模型后始终需要它。

编译需要指定许多参数,这些参数是专为培训网络而定制的。具体来说,用于训练网络和用于评估优化算法最小化的网络的优化算法。

例如,下面是编译定义的模型并指定随机梯度下降 (sgd) 优化算法和用于回归类型问题的均方误差 (mean_squared_error) 损失函数的示例。

model.compile(optimizer='sgd', loss='mean_squared_error')

或者,可以在作为编译步骤的参数提供之前创建和配置优化器。

algorithm = SGD(lr=0.1, momentum=0.3)
model.compile(optimizer=algorithm, loss='mean_squared_error')

预测建模问题的类型对可以使用的损耗函数的类型施加了约束。

例如,以下是不同预测模型类型的一些标准损耗函数:

回归: 平均平方错误或”mean_squared_error”。
二元分类: 对数损耗,也称为交叉熵或”binary_crossentropy”。
多类分类: 多类对数丢失或”categorical_crossentropy”。
最常见的优化算法是随机梯度下降,但 Keras 还支持一套其他最先进的优化算法,这些算法在很少或没有配置时运行良好。

可能最常用的优化算法,因为它们通常更好的性能是:
Stochastic Gradient Descent: 或”sgd”,这需要调整学习速率和动量
ADAM: 或”adam”,这需要调整学习率。
RMSprop: 或”rmsprop”,这需要调整学习速率。
最后,除了损失函数之外,还可以指定在拟合模型时要收集的指标。通常,要收集的最有用的附加指标是分类问题的准确性。要收集的指标按数组中的名称指定。

例如:

model.compile(optimizer='sgd', loss='mean_squared_error', metrics=['accuracy'])

3、训练网络

编译网络后,它可以训练数据,这意味着调整训练数据集上的权重。

训练网络需要指定训练数据,包括输入模式矩阵 X 和匹配输出模式数组 y。

网络采用反向传播算法进行训练,根据编译模型时指定的优化算法和损失函数进行优化。

反向传播算法要求为网络指定训练轮数或对训练数据集。

每一轮训练可以划分为称为批处理的输入输出模式对。这将定义在一轮训练内更新权重。这也是一种效率优化,确保一次不会将太多的输入数据加载到内存中。

训练网络的最小示例如下:

history = model.fit(X, y, batch_size=10, epochs=100)

训练网络以后,将返回一个历史记录对象,该对象提供模型在训练期间性能的摘要。这包括在编译模型时指定的损失和任何其他指标,每一轮训练都记录下来。

训练网络可能需要很长时间,从数秒到数小时到数天,具体取决于网络的大小和训练数据的大小。

默认情况下,每一轮训练的命令行上将显示一个进度条。这可能给您带来太大的噪音,或者可能会给环境带来问题,例如,如果您是交互式笔记本或 IDE。

通过将verbose参数设置为 2,可以将显示的信息量减小到每轮训练的损失。您可以通过将verbose设置为 1 来关闭所有输出。例如:

history = model.fit(X, y, batch_size=10, epochs=100, verbose=0)

4、评估网络

一旦网络被训练,就可以评估它。

网络可以根据训练数据进行评估,但这不能像以前看到的所有这些数据那样,提供网络作为预测模型的性能的有用指示。

我们可以在单独的数据集上评估网络的性能,在测试期间看不到。这将提供网络在将来预测不可见数据时的性能估计。

该模型评估所有测试模式的损失,以及编译模型时指定的任何其他指标,如分类准确性。返回评估指标列表。

例如,对于使用精度指标编译的模型,我们可以在新数据集上对其进行如下评估:

loss, accuracy = model.evaluate(X, y)

与训练网络一样,提供了详细的输出,以给出模型评估的进度。我们可以通过将verbose参数设置为 0 来关闭此选项。

loss, accuracy = model.evaluate(X, y, verbose=0)

5、进行预测

一旦我们对拟合模型的性能感到满意,我们就可以用它来预测新数据。

这和使用一系列新输入模式在模型上调用predict() 函数一样简单。

例如:

predictions = model.predict(X)

预测将返回网络输出层提供的格式。

在回归问题的情况下,这些预测可能采用问题格式,由线性激活函数提供。

对于二进制分类问题,预测可能是第一个类的概率数组,可以通过舍入转换为 1 或 0。

对于多类分类问题,结果可能采用概率数组(假设一个热编码的输出变量),可能需要使用 argmax() NumPy 函数转换为单个类输出预测。

或者,对于分类问题,我们可以使用 predict_classes)函数,该函数将自动将 uncrisp 预测转换为清晰的整数类值。

predictions = model.predict_classes(X)

与拟合和评估网络一样,提供详细的输出,以给出模型进行预测的进展。我们可以通过将verbose参数设置为 0 来关闭此选项。

predictions = model.predict(X, verbose=0)

一个LSTM示例

让我们用一个简单的小例子将所有的模块整合到一起。
此示例将使用学习 10 个数字序列的简单问题。我们将向网络显示一个数字,如 0.0,并期望它预测 0.1。然后显示 0.1,并期望它预测 0.2,等等到 0.9。
定义网络: 我们将在网络中构建一个具有1个输入时间步长和1个输入特征的LSTM神经网络,在LSTM隐藏层中构建10个内存单元,在具有线性(默认)激活功能的完全连接的输出层中构建1个神经元。
编译网络: 我们将使用有效的ADAM优化算法与默认配置和平均平方误差损失函数,因为它是一个回归问题。
训练网络: 我们将网络训练1000轮,并使用与训练集中模式数相等的批处理大小。我们还将关闭所有详细输出。
评估网络: 我们将在训练数据集上评估网络。通常,我们会在测试或验证集上评估模型。
进行预测: 我们将对训练输入数据进行预测。同样,我们通常会对不知道正确答案的数据进行预测。
完整的代码如下:

# 使用LSTM学习序列数据示例
from pandas import DataFrame
from pandas import concat
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
# 创建序列
length = 10
sequence = [i/float(length) for i in range(length)]
print(sequence)
# 创建 X/y 对儿
df = DataFrame(sequence)
df = concat([df.shift(1), df], axis=1)
df.dropna(inplace=True)
# 转换LSTM格式
values = df.values
X, y = values[:, 0], values[:, 1]
X = X.reshape(len(X), 1, 1)
# 1. 定义网络
model = Sequential()
model.add(LSTM(10, input_shape=(1,1)))
model.add(Dense(1))
# 2. 编译网络
model.compile(optimizer='adam', loss='mean_squared_error')
# 3. 训练网络
history = model.fit(X, y, epochs=1000, batch_size=len(X), verbose=0)
# 4. 评估网络
loss = model.evaluate(X, y, verbose=0)
print(loss)
# 5. 进行预测
predictions = model.predict(X, verbose=0)
print(predictions[:, 0])

运行此示例将生成以下输出,显示 10 个数字的原始输入序列、对整个序列进行预测时网络的均平方误差损失以及每个输入模式的预测。

注意: 由于算法或评估过程具有随机性,或数值精度的差异,您的结果可能会有所不同。考虑运行示例几次,并比较平均结果。

我们可以看到序列学得很好,特别是如果我们把预测四舍五入到小数点位。
运行结果

总结

在这篇文章中,您发现了使用 Keras 库的 LSTM 循环神经网络的 5 步生命周期。

具体来说,您了解到:

1、如何定义、编译、拟合、评估和预测 Keras 中的 LSTM 网络。
2、如何选择激活函数和输出层配置的分类和回归问题。
3、如何开发和运行您的第一个LSTM模型在Keras。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/197213.html原文链接:https://javaforall.net

(0)
上一篇 2025年9月8日 上午11:43
下一篇 2025年9月8日 下午12:15


相关推荐

  • Java贪吃蛇全代码

    Java贪吃蛇全代码用Java编写精典小游戏——贪吃蛇!前言  我想贪吃蛇应该是不少90后和00后的童年(我本人是01年的),回想起从前偷偷拿着我爹的诺基亚在被窝里玩贪吃蛇,不禁感慨万分,时间飞逝,没想到10年后的我也可以自己做一个贪吃蛇了。    该程序主要实现了以下功能:  1.按空格开始游戏、暂停游戏或重新开始游戏。  2.方向键控制蛇移动的方向。  3.蛇吃掉食物可以增长,并增加游戏分数(不会加快游戏速度)。  4.蛇咬到自己会结束游戏。  5.蛇撞到游戏区域外会结束游戏。    接下来放

    2022年6月26日
    37
  • 广东电信 DNS 设置更改

    广东电信 DNS 设置更改因为未知原因(真的不知中国电信为何如此,有空打10000问问),原先广东电信用户可以使用的DNS服务器,如202.96.128.68202.96.128.110,不能使用了。因此,如果你的ADSL是使用路由方式共享上网的,并且手动设置了DNS服务器地址为以上ip,将会出现上不了网的情况。这就需要把DNS服务器地址更…

    2022年7月11日
    45
  • OpenClaw本地化部署安装指南

    OpenClaw本地化部署安装指南

    2026年3月19日
    1
  • postgresql 数据库 alter table alter column set default 的一些实践

    postgresql 数据库 alter table alter column set default 的一些实践os centos7 4db postgresql10 11 创建表后 有时需要对表进行 setdefault 或者 dropdefault 设置 版本 cat etc centos releaseCentO 4 1708 Core su postgres psql c selectversio

    2026年3月16日
    2
  • Android之include避免代码重复

    在做布局时,经常有些部分是重复的,比如title或者foot的地方,最简单的办法当然是直接复制过去,这里介绍include的用法,有过c++或者c经验的同学一看就明白了,就是把另一个布局包含进来.先看下实现的效果:里面上下各有两个文字布局,是用include包含进去的,直接看代码activity_main.xml:

    2022年3月11日
    49
  • Linux rpm安装jdk1.8

    Linux rpm安装jdk1.8前言每次需要配置JDK的时候都需要去网上搜一下,这次专门写下博客以备后用,虽然这个博客实在是太!简!单!了!亲测CentOS6,CentOS7都没有问题第一步:卸载系统自带的JDKrpm-qa|grepjava#xxxyyyzzz为你要卸载的插件,插件之间以空格隔开rpm-e–nodepsxxxyyyzzz第二步:安装JDK1.8…

    2022年6月11日
    32

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号