深度学习 warmup 策略

深度学习 warmup 策略一 介绍 warmup 顾名思义就是热身 在刚刚开始训练时以很小的学习率进行训练 使得网络熟悉数据 随着训练的进行学习率慢慢变大 到了一定程度 以设置的初始学习率进行训练 接着过了一些 inter 后 学习率再慢慢变小 学习率变化 上升 平稳 下降 具体步骤 启用 warmup 设置 warmupsetp 一般等于 epoch inter per epoch 当 step 小于 warmupsetp 时 学习率等于基础学习率 当前 step warmup step 由于后者

一、介绍

     warmup顾名思义就是热身,在刚刚开始训练时以很小的学习率进行训练,使得网络熟悉数据,随着训练的进行学习率慢慢变大,到了一定程度,以设置的初始学习率进行训练,接着过了一些inter后,学习率再慢慢变小;学习率变化:上升——平稳——下降;

具体步骤:

        启用warm up,设置warm up setp(一般等于epoch*inter_per_epoch),当step小于warm up setp时,学习率等于基础学习率×(当前step/warmup_step),由于后者是一个小于1的数值,因此在整个warm up的过程中,学习率是一个递增的过程!当warm up结束后,学习率以基础学习率进行训练,再学习率开始递减

二、使用场景

1、当网络非常容易nan时候,采用warm up进行训练,可使得网络正常训练;

2、如果训练集损失很低,准确率高,但测试集损失大,准确率低,可用warm up;具体可看:https://blog.csdn.net/u0/article/details/

三、有效原因

这个问题目前还没有被充分证明,目前效果有:

  • 有助于减缓模型在初始阶段对mini-batch的提前过拟合现象,保持分布的平稳
  • 有助于保持模型深层的稳定性

在训练期间有如下情况:

     1、在训练的开始阶段,模型权重迅速改变

      2、mini-batch size较小,样本方差较大

        第一种因为刚刚开始的时候,模型对数据的“分布”理解为零,或者是说“均匀分布”(初始化一般都是以均匀分布来初始化);

在第一轮训练的时候,每个数据对模型来说都是新的,随着训练模型会很快地进行数据分布修正,这时候学习率就很大,很有可能在刚刚开始就会导致过拟合,后期需要要通过多轮训练才能拉回来。当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据过几遍了,或者说对当前的batch而言有了一些正确的先验,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。这个过程就也就是warmup。

        那后期为什么学习率又要减小呢?这就是我们正常训练时候,学习率降低有助于更好的收敛,当模型学习到一定的 程度,模型的分布就学习的比较稳定了。如果还用较大的学习率,就会破坏这种稳定性,导致网络波动比较大,现在已经十分接近了最优了,为了靠近这个最优点,我就就要很小的学习率

第二原因:如果有mini-batch内的数据分布方差特别大,这就会导致模型学习剧烈波动,使其学得的权重很不稳定,这在训练初期最为明显,最后期较为缓解

所以由于上面这两个原因,我们不能随便成倍减少学习率;

在resnet文章中,有说到如果一开始就用大的学习率,虽然最终会收敛,但之后测试准确率还是不会提高;如果用了warmup,在收敛后还能有所提高。也就是说,用warm up和不用warm up达到的收敛点,对之后模型能够达到最优点有影响。这说明不用warm up收敛到的点比用warm up收敛到的点更差。这也说明,如果刚刚开始学偏了的权重后面都拉不回来;

那么为什么以前神经网络没用warm up技巧呢?

主要原因是:

    1、y以前网络不够大、不够深

     2、数据集普遍较小

 

四、部分实现代码

1、tensorflow

if warmup: warmup_steps = int(batches_per_epoch * 5) warmup_lr = (initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)) return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)

或者 ,看代码来理解,这是摘抄的;

with tf.name_scope('learn_rate'): self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step') warmup_steps = tf.constant(self.warmup_periods * self.steps_per_period, dtype=tf.float64, name='warmup_steps') # warmup_periods epochs train_steps = tf.constant((self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_period, dtype=tf.float64, name='train_steps') self.learn_rate = tf.cond( pred=self.global_step < warmup_steps, true_fn=lambda: self.global_step / warmup_steps * self.learn_rate_init, false_fn=lambda: self.learn_rate_end + 0.5 * (self.learn_rate_init - self.learn_rate_end) * ( 1 + tf.cos((self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))) global_step_update = tf.assign_add(self.global_step, 1.0) """ 训练分为两个阶段,第一阶段里前面又划分出一段作为“热身阶段”: 热身阶段:learn_rate = (global_step / warmup_steps) * learn_rate_init 其他阶段:learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) * ( 1 + tf.cos((global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))

2、pytorch

该项目中有pytorch实现

https://github.com/ruinmessi/ASFF/issues/65

正文

1. 背景

学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。 其实在我们的大多数情况下,遇到 loss 变成 NaN 的情况大多数是由于学习率选择不当引起的

2. 学习率的设置 — “不同阶段不同值:上升 -> 平稳 -> 下降”

由于神经网络在刚开始训练的时候是非常不稳定的,因此刚开始的学习率应当设置得很低很低,这样可以保证网络能够具有良好的收敛性。但是较低的学习率会使得训练过程变得非常缓慢,因此这里会采用以较低学习率逐渐增大至较高学习率的方式实现网络训练的“热身”阶段,称为 warmup stage。但是如果我们使得网络训练的 loss 最小,那么一直使用较高学习率是不合适的,因为它会使得权重的梯度一直来回震荡,很难使训练的损失值达到全局最低谷。这个代码采用了 cosine 的衰减方式,这个阶段可以称为 consine decay stage。

3. tf-yolov3作者的相关源码

with tf.name_scope('learn_rate'): self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step') warmup_steps = tf.constant(self.warmup_periods * self.steps_per_period, dtype=tf.float64, name='warmup_steps') # warmup_periods epochs train_steps = tf.constant((self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_period, dtype=tf.float64, name='train_steps') self.learn_rate = tf.cond( pred=self.global_step < warmup_steps, true_fn=lambda: self.global_step / warmup_steps * self.learn_rate_init, false_fn=lambda: self.learn_rate_end + 0.5 * (self.learn_rate_init - self.learn_rate_end) * ( 1 + tf.cos((self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))) global_step_update = tf.assign_add(self.global_step, 1.0) """ 训练分为两个阶段,第一阶段里前面又划分出一段作为“热身阶段”: 热身阶段:learn_rate = (global_step / warmup_steps) * learn_rate_init 其他阶段:learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) * ( 1 + tf.cos((global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi)) """ 

4. 应用场景

(1)训练出现NaN:当网络非常容易nan时候,采用warm up进行训练,可使得网络正常训练;

(2)过拟合:训练集损失很低,准确率高,但测试集损失大,准确率低,可用warm up;具体可看: Resnet-18-训练实验-warm up操作

5. 应用原理/优势来源

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/225795.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 上午8:32
下一篇 2026年3月17日 上午8:32


相关推荐

  • vue-cli 3.0之跨域请求devServer代理配置

    vue-cli 3.0之跨域请求devServer代理配置概念什么是同源策略同源策略是一种约定 它是浏览器最核心也最基本的安全功能 如果缺少了同源策略 则浏览器的正常功能可能都会受到影响 可以说 Web 是构建在同源策略基础之上的 浏览器只是针对同源策略的一种实现 所谓同源是指 协议 域名 端口都相同什么是跨域跨域就是不同源 就是不满足协议 域名 端口都相同的约定如 看下面的链接是否与 http www test com index ht

    2026年3月19日
    2
  • 词向量表示[通俗易懂]

    词向量表示[通俗易懂]1、语言表示语音中,用音频频谱序列向量所构成的矩阵作为模型的输入;在图像中,用图像的像素构成的矩阵数据作为模型的输入。这些都可以很好表示语音/图像数据。而语言高度抽象,很难刻画词语之间的联系,比如“麦克风”和“话筒”这样的同义词,从字面上也难以看出这两者意思相同,即“语义鸿沟”现象。1.1、分布假说上下文相似的词,其语义也相似。1.2、语言模型文本学习:词频、词的共现、词的搭配。语言模型判定一句话是否为自然语言。机器翻译、拼写纠错、音字转换、问答系统、语音识别等应用在得到若干候…

    2022年5月25日
    51
  • gridview属性_GridView

    gridview属性_GridViewGridView在生成HTML代码的时候会自动加上style=”border-collapse:collapse;”以及border=1,rules=”all”这些属性,这些在IE下都没什么影响,但是在FF下就会影响显示,style=”border-collapse:collapse;”;是由于设置了CellSpacing=”0″产生的,当设置CellSpacing=”1″后就没有,可以去掉sty

    2026年3月11日
    6
  • 使用MySQL Workbench建立数据库,建立新的表,向表中添加数据

    使用MySQL Workbench建立数据库,建立新的表,向表中添加数据初学数据库,记录一下所学的知识。我用的MySQL数据库,使用MySQLWorkbench管理。下面简单介绍一下如何使用MySQLWorkbench建立数据库,建立新的表,为表添加数据。  点击上图中的“加号”图标,新建一个连接,    如上图,先输入数据库的账号密码,帐号默认为root,填好密码后点击“OK”,连接就建立好了,建立完成后,会出现一个长方

    2026年3月6日
    5
  • Python:Flask使用jsonify格式化时间

    Python:Flask使用jsonify格式化时间代码如下#-*-coding:utf-8-*-fromdatetimeimportdatetime,datefromflask.jsonimportJSONEncoderclassCustomJSONEncoder(JSONEncoder):defdefault(self,obj):ifisinstance(obj,datetime):returnobj.strftime(‘%Y-%m-%d%H:%M:%

    2022年5月20日
    82
  • Tasklist命令详解

    Tasklist命令详解“Tasklist”命令是一个用来显示运行在本地或远程计算机上的所有进程的命令行工具,带有多个执行参数。

    2022年5月3日
    67

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号