PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景

PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景PyTorch 中有一些基础概念在构建网络的时候很重要 比如 nn Module nn ModuleList nn Sequential 这些类我们称之为容器 containers 因为我们可以添加模块 module 到它们之中 这些容器之间很容易混淆 本文中我们主要学习一下 nn ModuleList 和 nn Sequential 并判断在什么时候用哪一个比较合适

PyTorch 中有一些基础概念在构建网络的时候很重要,比如 nn.Module, nn.ModuleList, nn.Sequential,这些类我们称之为容器 (containers),因为我们可以添加模块 (module) 到它们之中。这些容器之间很容易混淆,本文中我们主要学习一下 nn.ModuleList 和 nn.Sequential,并判断在什么时候用哪一个比较合适。本文中的例子使用的是 PyTorch 1.0 版本。

nn.ModuleList

首先说说 nn.ModuleList 这个类,你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。描述看起来很枯燥,我们来看几个例子。

第一个网络,我们先来看看使用 nn.ModuleList 来构建一个小型网络,包括3个全连接层:

class net1(nn.Module): def __init__(self): super(net1, self).__init__() self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)]) def forward(self, x): for m in self.linears: x = m(x) return x net = net1() print(net) # net1( # (modules): ModuleList( # (0): Linear(in_features=10, out_features=10, bias=True) # (1): Linear(in_features=10, out_features=10, bias=True) # ) # ) for param in net.parameters(): print(type(param.data), param.size()) # 
   
     torch.Size([10, 10]) 
    # 
   
     torch.Size([10]) 
    # 
   
     torch.Size([10, 10]) 
    # 
   
     torch.Size([10]) 
    

我们可以看到,这个网络包含两个全连接层,他们的权重 (weithgs) 和偏置 (bias) 都在这个网络之内。接下来我们看看第二个网络,它使用 Python 自带的 list:

class net2(nn.Module): def __init__(self): super(net2, self).__init__() self.linears = [nn.Linear(10,10) for i in range(2)] def forward(self, x): for m in self.linears: x = m(x) return x net = net2() print(net) # net2() print(list(net.parameters())) # [] 

显然,使用 Python 的 list 添加的全连接层和它们的 parameters 并没有自动注册到我们的网络中。当然,我们还是可以使用 forward 来计算输出结果。但是如果用 net2 实例化的网络进行训练的时候,因为这些层的 parameters 不在整个网络之中,所以其网络参数也不会被更新。

好,看到这里,我们大致明白了 nn.ModuleList 是干什么的了:它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。但是,我们需要注意到,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,比如:

class net3(nn.Module): def __init__(self): super(net3, self).__init__() self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)]) def forward(self, x): x = self.linears[2](x) x = self.linears[0](x) x = self.linears[1](x) return x net = net3() print(net) # net3( # (linears): ModuleList( # (0): Linear(in_features=10, out_features=20, bias=True) # (1): Linear(in_features=20, out_features=30, bias=True) # (2): Linear(in_features=5, out_features=10, bias=True) # ) # ) input = torch.randn(32, 5) print(net(input).shape) # torch.Size([32, 30]) 

根据 net3 的结果,我们可以看出来这个 ModuleList 里面的顺序并不能决定什么,网络的执行顺序是根据 forward 函数来决定的。如果你非要 ModuleList 和 forward 中的顺序不一样, PyTorch 表示它无所谓,但以后 review 你代码的人可能会意见比较大。

我们再考虑另外一种情况,既然这个 ModuleList 可以根据序号来调用,那么一个模块是否可以在 forward 函数中被调用多次呢?答案当然是可以的,但是,被调用多次的模块,是使用同一组 parameters 的,也就是它们的参数是完全一样的,无论你之后怎么更新。例子如下,虽然在 forward 中我们用了 nn.Linear(10,10) 两次,但是它们只有一组参数。这么做有什么用处呢,我目前没有想到…

class net4(nn.Module): def __init__(self): super(net4, self).__init__() self.linears = nn.ModuleList([nn.Linear(5, 10), nn.Linear(10, 10)]) def forward(self, x): x = self.linears[0](x) x = self.linears[1](x) x = self.linears[1](x) return x net = net4() print(net) # net4( # (linears): ModuleList( # (0): Linear(in_features=5, out_features=10, bias=True) # (1): Linear(in_features=10, out_features=10, bias=True) # ) # ) for name, param in net.named_parameters(): print(name, param.size()) # linears.0.weight torch.Size([10, 5]) # linears.0.bias torch.Size([10]) # linears.1.weight torch.Size([10, 10]) # linears.1.bias torch.Size([10]) 

nn.Sequential

现在我们来研究一下 nn.Sequential,不同于 nn.ModuleList,它已经实现的 forward 函数,而且里面的模块是按照顺序进行排列的,所以我们必须确保前一个模块的输出大小和下一个模块的输入大小是一致的,如下面的例子所示:

class net5(nn.Module): def __init__(self): super(net5, self).__init__() self.block = nn.Sequential(nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU()) def forward(self, x): x = self.block(x) return x net = net5() print(net) # net5( # (block): Sequential( # (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) # (1): ReLU() # (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) # (3): ReLU() # ) # ) 

下面给出了两个初始化的例子,来自于 官网教程。第二个初始化的时候我们用到了 OrderedDict 来指定每个 module 的名字,而不是采用默认的命名方式 (按序号 0,1,2,3…) 。

# Example of using Sequential model1 = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) print(model1) # Sequential( # (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) # (1): ReLU() # (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) # (3): ReLU() # ) # Example of using Sequential with OrderedDict import collections model2 = nn.Sequential(collections.OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) print(model2) # Sequential( # (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) # (relu1): ReLU() # (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) # (relu2): ReLU() # ) 

有同学可能发现了,诶,你这个 model1 和 从类 net5 实例化来的 net 有什么区别吗?是没有的。这两个网络是相同的,因为 nn.Sequential 就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法 (method) 它都有。并且直接使用 nn.Sequential 不用写 forward 函数,因为它内部已经帮你写好了。

这时候有同学该说了,既然 nn.Sequential 这么好,我以后都直接用它了。如果你确定 nn.Sequential 里面的顺序是你想要的,而且不需要再添加一些其他处理的函数 (比如 nn.functional 里面的函数,nn 与 nn.functional 有什么区别? ),那么完全可以直接用 nn.Sequential。这么做的代价就是失去了部分灵活性,毕竟不能自己去定制 forward 函数里面的内容了。

一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

nn.ModuleList 和 nn.Sequential: 到底该用哪个

前边我们已经简单介绍了这两个类,现在我们来讨论一下在两个不同的场景中,选择哪一个比较合适。

场景一,有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们,比如:

layers = [nn.Linear(10, 10) for i in range(5)] 

这个时候,很自然而然地,我们会想到使用 ModuleList,像这样:

class net6(nn.Module): def __init__(self): super(net6, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)]) def forward(self, x): for layer in self.linears: x = layer(x) return x net = net6() print(net) # net6( # (linears): ModuleList( # (0): Linear(in_features=10, out_features=10, bias=True) # (1): Linear(in_features=10, out_features=10, bias=True) # (2): Linear(in_features=10, out_features=10, bias=True) # ) # ) 

这个是比较一般的方法,但如果不想这么麻烦,我们也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。所以在 场景一 中,我个人觉得使用 net7 这种方法比较方便和整洁

class net7(nn.Module): def __init__(self): super(net7, self).__init__() self.linear_list = [nn.Linear(10, 10) for i in range(3)] self.linears = nn.Sequential(*self.linears_list) def forward(self, x): self.x = self.linears(x) return x net = net7() print(net) # net7( # (linears): Sequential( # (0): Linear(in_features=10, out_features=10, bias=True) # (1): Linear(in_features=10, out_features=10, bias=True) # (2): Linear(in_features=10, out_features=10, bias=True) # ) # ) 

下面我们考虑 场景二,当我们需要之前层的信息的时候,比如 ResNets 中的 shortcut 结构,或者是像 FCN 中用到的 skip architecture 之类的,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList 比较方便,一个非常简单的例子如下:

class net8(nn.Module): def __init__(self): super(net8, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(30, 50)]) self.trace = [] def forward(self, x): for layer in self.linears: x = layer(x) self.trace.append(x) return x net = net8() input = torch.randn(32, 10) output = net(input) for each in net.trace: print(each.shape) # torch.Size([32, 20]) # torch.Size([32, 30]) # torch.Size([32, 50]) 

我们使用了一个 trace 的列表来储存网络每层的输出结果,这样如果以后的层要用的话,就可以很方便的调用了。

总结

本文中我们通过一些实例学习了 ModuleList 和 Sequential 这两种 nn containers,ModuleList 就是一个储存各种模块的 list,这些模块之间没有联系,没有实现 forward 功能,但相比于普通的 Python list,ModuleList 可以把添加到其中的模块和参数自动注册到网络上。而Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部 forward 功能已经实现,可以使代码更加整洁。在不同场景中,如果二者都适用,那就看个人偏好了。非常推荐大家看 PyTorch 官方的 TorchVision 下面的模型实现的代码,能学到很多构建网络的技巧。

Reference:

  1. nn.ModuleList 和 Sequential 由来、用法和实例 —— 写网络模型
  2. The difference in usage between nn.ModuleList and python list
  3. When should I use nn.ModuleList and when should I use nn.Sequential?
  4. Pytorch: how and when to use Module, Sequential, ModuleList and ModuleDict
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/216129.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午12:29
下一篇 2026年3月18日 下午12:29


相关推荐

  • 2016年总结-JAVA程序员

    2016年总结-JAVA程序员

    2020年11月12日
    181
  • 一阶倒立摆分析_倒立摆受力分析

    一阶倒立摆分析_倒立摆受力分析摆的运动是两种运动的叠加:1.平动,包含x方向和y方向。2.转动,转轴为质心。尽管物理上的转轴是其端点,但这个端点同时也是摆的受力点。在端点(非中心)施加垂直于摆臂的力,摆将绕其质心转动。  因为摆的重力作用于其转轴(质心),因此摆自身的重力对摆不施加力矩。这可以算作将质心作为转轴来分析的一个优势。   …

    2022年8月18日
    8
  • 50个有用的国外扁平化设计PSD素材

    50个有用的国外扁平化设计PSD素材最近看了很多扁平化设计作品 然后自己新项目尝试下 FlatDesign 发现看似简单的图形但设计起来并不是想象中那么容易 特别是配色方面 很难整理一个好的方案 好在还有一个扁平化配色方案库 这里的配色给我很多配色灵感 nbsp nbsp 最近看了很多扁平化设计作品 然后自己新项目尝试下 FlatDesign 发现看似简单的图形但设计起来并不是想象中那么容易 特别是配色方面 很难整理一个好的

    2026年3月18日
    2
  • Centos7 postfix dovecot安装配置

    Centos7 postfix dovecot安装配置基本流程及软件版本最近在为公司搭建私有服务器环境 调研了一些开源或付费软件 最后选择自己配置软件环境 以下为配置过程及测试 Postfix 一种邮件传输代理软件 通常用来发送邮件 Dovcot 邮件检索代理软件 通常用来接收邮件 发送流程 客户端 MailClient 发出邮件 amp amp amp gt Postfix 接收 通过 Dovecot 进行认证 查询数据库是否存在用户 amp amp amp gt Dovecot 把

    2026年3月17日
    2
  • Java中&0xFF是什么意思?计算机的原码、补码和反码

    Java中&0xFF是什么意思?计算机的原码、补码和反码公司项目中有向MCU发数据的代码,新来的同事对其中的&0xFF很不理解,我解释了很多遍他还是蒙圈状态,可能我的表达能力太差,想想还是用一篇博客来详细说明吧,代码如下:更新:07月10日,有个小伙伴对这种操作各种不习惯,怎么解释他都想不明白,所以增加了代码注释为什么要加上“&0xFF”?拆分理解下0xFF是16进制的表达方式,F是15;十进制为:255,二进制为:11111111

    2022年6月19日
    850
  • linux删除软连接命令_linux删除链接文件夹

    linux删除软连接命令_linux删除链接文件夹linux删除软链接的正确做法

    2026年4月16日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号