(深度学习)Pytorch之dropout训练

(深度学习)Pytorch之dropout训练(深度学习)Pytorch学习笔记之dropout训练Dropout训练实现快速通道:点我直接看代码实现Dropout训练简介在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:那么,我们在深度学习的有力工具——Pytor…

大家好,又见面了,我是你们的朋友全栈君。

(深度学习)Pytorch学习笔记之dropout训练

Dropout训练实现快速通道:点我直接看代码实现

Dropout训练简介

在深度学习中,dropout训练时我们常常会用到的一个方法——通过使用它,我们可以可以避免过拟合,并增强模型的泛化能力。

通过下图可以看出,dropout训练训练阶段所有模型共享参数,测试阶段直接组装成一个整体的大网络:
在这里插入图片描述

那么,我们在深度学习的有力工具——Pytorch中如何实现dropout训练呢?

易错大坑

网上查找到的很多实现都是这种形式的:

        out = F.dropout(out, p=0.5)

这种形式的代码非常容易误导初学者,给人带来很大的困扰:

  • 首先,这里的F.dropout实际上是torch.nn.functional.dropout的简写(很多文章都没说清这一点,就直接给个代码),我尝试了一下我的Pytorch貌似无法使用,可能是因为版本的原因。
  • 其次,torch.nn.functional.dropout()还有个大坑:F.dropout()相当于引用的一个外部函数,模型整体的training状态变化也不会引起F.dropout这个函数的training状态发生变化。因此,上面的代码实质上就相当于out = out

因此,如果你非要使用torch.nn.functional.dropout的话,推荐的正确方法如下(这里默认你已经import torch.nn as nn了):

       out = nn.functional.dropout(out, p=0.5, training=self.training)

推荐代码实现方法

这里更推荐的方法是:nn.Dropout(p),其中p是采样概率。nn.Dropout实际上是对torch.nn.functional.dropout的一个包装, 也将self.training传入了其中,可以有效避免前面所说的大坑。

下面给出一个三层神经网络的例子:

import torch.nn as nn


input_size = 28 * 28   
hidden_size = 500   
num_classes = 10    


# 三层神经网络
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入层到影藏层
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  # 影藏层到输出层
        self.dropout = nn.Dropout(p=0.5)  # dropout训练

    def forward(self, x):
        out = self.fc1(x)
        out = self.dropout(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
   

model = NeuralNet(input_size, hidden_size, num_classes)
model.train()
model.eval()

另外还有一点需要说明的是,训练阶段随机采样时需要用model.train(),而测试阶段直接组装成一个整体的大网络时需要使用model.eval():

  • 如果你二者都没使用的话,默认情况下实际上是相当于使用了model.train(),也就是开启dropout随机采样了——这样你如果你是在测试的话,准确率可能会受到影响。
  • 如果你不希望开启dropout训练,想直接以一个整体的大网络来训练,不需要重写一个网络结果,而只需要在训练阶段开启model.eval()即可。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/133466.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月1日 下午7:40
下一篇 2022年5月1日 下午7:40


相关推荐

  • getClass()和getClassLoader()区别 以及ClassLoader详解及用途(文件加载,类加载)

    getClass()和getClassLoader()区别 以及ClassLoader详解及用途(文件加载,类加载)获得ClassLoader的几种方法可以通过如下3种方法得到ClassLoader  this.getClass().getClassLoader(); // 使用当前类的ClassLoader  Thread.currentThread().getContextClassLoader(); // 使用当前线程的ClassLoader  ClassLoader.getSystemCla

    2022年5月4日
    78
  • 从PSF到PMF,再到PRF

    从PSF到PMF,再到PRFPMF的概念,我想大家都熟知了,但最近我经常看到,也感觉在PMF之前,还有个更早期的概念,可以用PSF来描述。简单对比就是:Problem-Solution-Fit:价值…

    2022年5月23日
    40
  • 【金融科技前沿】【长文】金融监管、监管科技以及银行业监管报送概述「建议收藏」

    【金融科技前沿】【长文】金融监管、监管科技以及银行业监管报送概述「建议收藏」上周金融科技前沿课程的主题是《监管科技》,韩海燕老师从金融监管引入,介绍了我国的金融监管体系,接着进入监管科技的详细讲解。我觉得最主要的是弄清楚监管科技的定义,以及在实际的银行业应用场景中具体的运作流程是怎么样的。韩老师讲的很全面,将ABCD等金融科技手段在监管系统中是如何运作的讲的很清楚,收获颇丰,但是比较少涉及到监管的对象和内容,仍没有很清楚监管机构是要监管什么东西?监管机构要求银行金融业机构报送的资料有哪些?这些报送要求的目的分别是什么?所以这篇文章分为三个部分,一是介绍金融监管,二是介绍监科技,三.

    2022年5月6日
    271
  • Authentication failure. Retrying – 彻底解决vagrant up时警告

    Authentication failure. Retrying – 彻底解决vagrant up时警告

    2022年2月9日
    97
  • java .endswith_Java endsWith() 方法

    java .endswith_Java endsWith() 方法JavaendsWith 方法 endsWith 方法用于测试字符串是否以指定的后缀结束 语法 publicboolea Stringsuffix 参数 suffix 指定的后缀 返回值如果参数表示的字符序列是此对象表示的字符序列的后缀 则返回 true 否则返回 false 注意 如果参数是空字符串 或者等于此 String 对象 用 equals Object

    2026年3月16日
    2
  • JavaScript中的JSON序列化/反序列化

    JavaScript中的JSON序列化/反序列化JSON1 JSON 简介 2 JSON 与 JSObject 区别 3 对象序列化 3 1JSON 序列化 3 2JSON 反序列化 1 JSON 简介 JSON JavaScriptOb JavaScript 对象简谱 是一种轻量级的数据交换格式 JSON 是一种语法 用来序列化对象 数组 数值 字符串 布尔值和 null 不包含 undefined JSON 可以描述三种格式的数据 object 无序的 键 值 集合 array 有序的值集合 value 具体可参考

    2026年3月18日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号