(深度学习)Pytorch之dropout训练

(深度学习)Pytorch之dropout训练(深度学习)Pytorch学习笔记之dropout训练Dropout训练实现快速通道:点我直接看代码实现Dropout训练简介在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:那么,我们在深度学习的有力工具——Pytor…

大家好,又见面了,我是你们的朋友全栈君。

(深度学习)Pytorch学习笔记之dropout训练

Dropout训练实现快速通道:点我直接看代码实现

Dropout训练简介

在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。

通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:
在这里插入图片描述

那么,我们在深度学习的有力工具——Pytorch中如何实现dropout训练呢?

易错大坑

网上查找到的很多实现都是这种形式的:

        out = F.dropout(out, p=0.5)

这种形式的代码非常容易误导初学者,给人带来很大的困扰:

  • 首先,这里的F.dropout实际上是torch.nn.functional.dropout的简写(很多文章都没说清这一点,就直接给个代码),我尝试了一下我的Pytorch貌似无法使用,可能是因为版本的原因。
  • 其次,torch.nn.functional.dropout()还有个大坑:F.dropout()相当于引用的一个外部函数,模型整体的training状态变化也不会引起F.dropout这个函数的training状态发生变化。因此,上面的代码实质上就相当于out = out

因此,如果你非要使用torch.nn.functional.dropout的话,推荐的正确方法如下(这里默认你已经import torch.nn as nn了):

       out = nn.functional.dropout(out, p=0.5, training=self.training)

推荐代码实现方法

这里更推荐的方法是:nn.Dropout(p),其中p是采样概率。nn.Dropout实际上是对torch.nn.functional.dropout的一个包装, 也将self.training传入了其中,可以有效避免前面所说的大坑。

下面给出一个三层神经网络的例子:

import torch.nn as nn


input_size = 28 * 28   
hidden_size = 500   
num_classes = 10    


# 三层神经网络
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入层到影藏层
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  # 影藏层到输出层
        self.dropout = nn.Dropout(p=0.5)  # dropout训练

    def forward(self, x):
        out = self.fc1(x)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
   

model = NeuralNet(input_size, hidden_size, num_classes)
model.train()
model.eval()

另外还有一点需要说明的是,训练阶段随机采样时需要用model.train(),而测试阶段直接组装成一个整体的大网络时需要使用model.eval():

  • 如果你二者都没使用的话,默认情况下实际上是相当于使用了model.train(),也就是开启dropout随机采样了——这样你如果你是在测试的话,准确率可能会受到影响。
  • 如果你不希望开启dropout训练,想直接以一个整体的大网络来训练,不需要重写一个网络结果,而只需要在训练阶段开启model.eval()即可。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/133466.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 纯CSS实现表单验证[通俗易懂]

    纯CSS实现表单验证[通俗易懂]纯CSS实现表单验证

    2022年4月21日
    77
  • typora导出pdf文件缺失

    typora导出pdf文件缺失typora导出pdf文件缺失,原因很可能是在正文中存在<script>标签,比如:当做了脚本执行了,所以很可能在此之后的内容都会确实,调整方案为,用“包裹起来script标签:

    2022年5月20日
    34
  • 如何编程简单的病毒_永恒之蓝病毒如何传播

    如何编程简单的病毒_永恒之蓝病毒如何传播永恒之蓝病毒2018《开发者经济学:开发者国家现状》第15版已发布,它提供了一些非常有趣的见解。SlashData在2018年5月至6月的七周时间内进行了大规模调查,调查覆盖了167个国家的20,500多名开发人员。开发者国家状况报告有六个主要重点领域,包括:数据科学开发人员有兴趣破坏全球经济,但仍致力于定制软件编程语言社区–更新不断发展的技术和新渠道…

    2022年10月9日
    3
  • 【Python】python面试题

    【Python】python面试题一些Python面试题注:本面试题来源于网络,部分内容摘自http://www.cnblogs.com/goodhacker/p/3366618.html1.(1)python下多线程的限制以及

    2022年7月5日
    24
  • js的数据类型有哪些?[通俗易懂]

    js的数据类型有哪些?[通俗易懂]数据类型一、数据类型:基本数据类型(值类型):字符串(String)、数字(Number)、布尔(Boolean)、对空(Null)、未定义(Undefined)。引用数据类型(对象类型):对象(Object)、数组(Array)、函数(Function)。特殊的对象:正则(RegExp)和日期(Date)。特殊类型:underfined未定义、Null空对象、Infinate无穷、NAN非数字基本数据类型的值直接在栈内存中存储,值与值之间独立存在,修改一个变量不会影响.

    2025年9月19日
    4
  • SQL 获取当前系统时间

    SQL 获取当前系统时间SQL获取当前系统时间

    2022年10月19日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号