pytorch中RNN参数的详细解释

pytorch中RNN参数的详细解释总述 第一次看到这个函数时 脑袋有点懵 总结了下总共有五个问题 1 这个 input size 是啥 要输入啥 featurenum 又是啥 2 这个 hidden size 是啥 要输入啥 featurenum 又是啥 3 不是说 RNN 会有很多个节点连在一起的吗 这怎么定义连接的节点数呢 4 num layer 中说的 stack 是怎么 stack 的 5 怎么输出会有两个东西呀 outp

总述:

第一次看到这个函数时,脑袋有点懵,总结了下总共有五个问题:

1.这个input_size是啥?要输入啥?feature num又是啥?

2.这个hidden_size是啥?要输入啥?feature num又是啥?

3.不是说RNN会有很多个节点连在一起的吗?这怎么定义连接的节点数呢?

4.num_layer中说的stack是怎么stack的?

5.怎么输出会有两个东西呀output,hn

此篇博客介绍pytorch中RNN的一些参数,并且解决以上五个问题

1.Pytorch中的RNN

pytorch中RNN参数的详细解释

pytorch中RNN参数的详细解释

2.input_size是啥?

说白了input_size无非就是你输入RNN的维度,比如说NLP中你需要把一个单词输入到RNN中,这个单词的编码是300维的,那么这个input_size就是300.这里的input_size其实就是规定了你的输入变量的维度。用f(wX+b)来类比的话,这里输入的就是X的维度。

ps:  “维度”可能造成了大家的误解,专业一点的说法是“特征数量”或者“通道数”。如果你有一个【bs * sequence_length * hidden_dim】的向量,我这里的维度指的是这个“hidden_dim”.

3.hidden_size是啥?

和最简单的BP网络一样的,每个RNN的节点实际上就是一个BP嘛,包含输入层,隐含层,输出层。这里的hidden_size呢,你可以看做是隐含层中,隐含节点的个数。(讲到这里还不清楚的,请复习一下最简单的三层神经网络的架构)。

pytorch中RNN参数的详细解释

那个输入层的三个节点代表输入维度为3,也就是input_size=3,然后这个hidden_size就是5了。当然这是是对于RNN某一个节点而言的,那么如何规定RNN的节点个数呢?

4.如何规定节点个数?

事实上,节点个数并不需要规定,你的输入序列是这样子的,[x1,x2,x3,x4,x5],那么input_size呢就是你的xi的维度,而你的RNN的节点数呢,就是由你的序列长度决定的,在这里我们的序列长度是5,所以会有5个节点。那么问题来了,我咋知道你的序列长度呢?pytorch里面不是只有input_size的参数吗?实际上,你声明RNN是这样声明的

self.encoder = nn.RNN(input_size=300,hidden_size=128,dropout=0.5)

但是你用的时候;

output,hn = self.encoder(encoder_input,encoder_hidden)

你会把你的数据丢进去吧,也就是你把encoder_input这一整个序列丢进去了,那么序列长度他不就知道了?

5.num_layers是啥?

一开始你是不是以为这个就是RNN的节点数呀,hhh,然而并不是:),如果num_layer=2的话,表示两个RNN堆叠在一起。那么怎么堆叠的呢?

如果是num_layer==1的话:

è¿éåå¾çæè¿°

如果num_layer==2的话:

è¿éåå¾çæè¿°

ok了~最后再来看看最后一个问题

6.hn,output分别是啥?

看图找答案:

pytorch中RNN参数的详细解释

hn就是RNN的最后一个隐含状态。

经评论区的提醒,output是最后一层所有节点的hn集合,上图有一点点错误,请见谅~

如果还不理解的话,可以参见某个大佬总结的RNN,%%%%%%%%

RNN_机器学习/NLP/搜广推/算法开发工程/大数据-CSDN博客_rnn

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/175871.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月26日 下午11:13
下一篇 2026年3月26日 下午11:14


相关推荐

  • 数据分析模型(二):模糊聚类分析方法及实例(附完整代码)

    数据分析模型(二):模糊聚类分析方法及实例(附完整代码)聚类分析是数据挖掘技术中的一种重要的方法 可以作为一个独立的工具来获得数据分布情况 它广泛地应用于模式识别 数据分析 图像处理 生物学 经济学等许多领域 聚类分析方法是数理统计中研究 物以类聚 的一种多元分析方法 及用数学定量地确定样品的亲疏关系 从而客观地分型化类 由于事物本身在很多情况下都带有模糊性 因此把模糊数学方法引入聚类分析 能使分类更切合实际 我们所应用的模糊聚类方法是基于模糊相似关

    2025年11月16日
    6
  • UserDetailsService小记

    UserDetailsService小记Resourcepriv 当你采用上面的方式引入 UserDetailsS 时 你需要像下面一样在 Service 中加上名称 因为 UserDetailsS 中会先产生一个 InMemoryUser 它也是 UserDetailsS 的实现类 Spring 选择了 InMemoryUser 这时 U

    2026年3月17日
    1
  • Java的输入输出语句_c语言有没有输入输出语句

    Java的输入输出语句_c语言有没有输入输出语句一、概述  输入输出可以说是计算机的基本功能。作为一种语言体系,java中主要按照流(stream)的模式来实现。其中数据的流向是按照计算机的方向确定的,流入计算机的数据流叫做输入流(inputStream),由计算机发出的数据流叫做输出流(outputStream)。Java语言体系中,对数据流的主要操作都封装在java.io包中,通过java.io包中的类可以实现计算机对数据的输入、输出操作…

    2022年4月19日
    41
  • 简单剖析B树(B-Tree)与B+树

    简单剖析B树(B-Tree)与B+树注意 首先需要说明的一点是 B 树就是 B 树 没有所谓的 B 减树引言 我们都知道二叉查找树的查找的时间复杂度是 logN 其查找效率已经足够高了 那为什么还有 树和 树的出现呢 难道它两的时间复杂度比二叉查找树还小吗 答案当然不是 树和 树的出现是因为另外一个问题 那就是磁盘 众所周知 操作的效率很低 那么 当在大量数据存储中 查询时我们不能一下子将所有数据加载到

    2026年3月17日
    3
  • 图论完备之旅

    图论完备之旅

    2021年11月15日
    52
  • swal弹窗,sweetalert2具有相同功能的多个swal[通俗易懂]

    swal弹窗,sweetalert2具有相同功能的多个swal[通俗易懂]I’dliketomakeaconditionandcallaswalforeachone(Sweetalert2).Butonlyoneoftheswalruns.HowcanIdoit?functionvalidateEmail(email){varregex=/\S+@\S+\.\S+/;returnregex.test(emai…

    2025年5月23日
    6

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号