TensorFlow中学习率[通俗易懂]

TensorFlow中学习率[通俗易懂]学习率学习率属于超参数。学习率决定梯度下降速度的快慢,学习率越大,速度越快;学习率越小,速度越慢。如果学习率过大,很可能会越过最优值;反而如果学习率过小,优化的效率可能过低,长时间算法无法收敛。所以学习率对于算法性能的表现至关重要。指数衰减学习率指数衰减学习率是在学习率的基础上增加了动态变化的机制,会随着梯度下降变化而动态变化tf.train.expo…

大家好,又见面了,我是你们的朋友全栈君。

学习

学习率属于超参数。学习率决定梯度下降速度的快慢,学习率越大,速度越快;学习率越小,速度越慢。如果学习率过大,很可能会越过最优值;反而如果学习率过小,优化的效率可能过低,长时间算法无法收敛。所以学习率对于算法性能的表现至关重要。

 

 

 

指数衰减学习率

 

指数衰减学习率是在学习率的基础上增加了动态变化的机制,会随着梯度下降变化而动态变化

 

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)

 

  • learn_rate:事先设定的初始学习率
  • global_step:训练轮数
  • decay_steps:衰减速度。staircase=True:代表了完整的使用一遍训练数据所需要的迭代轮数(=总训练样本数/每个batch中的训练样本数)
  • decay_rate:衰减系数
  • staircase:默认为False,此时学习率随迭代轮数的变化是连续的(指数函数);为 True 时,global_step/decay_steps 会转化为整数,此时学习率便是阶梯函数

步骤:

  1. 首先使用较大学习率(目的:为快速得到一个比较优的解);
  2. 然后通过迭代逐步减小学习率(目的:为使模型在训练后期更加稳定);

模板:


global_step = tf.Variable(0)

 

learning_rate = tf.train.exponential_decay(0.1, global_step, 1, 0.96, staircase=True)     #生成学习率

 

learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(....., global_step=global_step)  #使用指数衰减学习率

 

实例代码:

TRAINING_STEPS = 100
global_step = tf.Variable(0)
LEARNING_RATE = tf.train.exponential_decay(
    0.1, global_step, 1, 0.96, staircase=True)

x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
y = tf.square(x)
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(
    y, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(TRAINING_STEPS):
        sess.run(train_op)
        if i % 10 == 0:
            LEARNING_RATE_value = sess.run(LEARNING_RATE)
            x_value = sess.run(x)
            print("After %s iteration(s): x%s is %f, learning rate is %f." %
                  (i + 1, i + 1, x_value, LEARNING_RATE_value))

 

关于global_step的探究:

  • global_step – 用于衰减计算的全局步骤。 一定不为负数。
  • 喂入一次 BACTH_SIZE 计为一次 global_step
  • 每间隔decay_steps次更新一次learning_rate值

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/137629.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • json字符串转map对象_java string 转jsonobject

    json字符串转map对象_java string 转jsonobjectMap转JSON字符串、String转JSONObject、JSONObject转JSON数组1.将Map转成JSON字符串:JSONObject.toJSONString();//请求参数Map<String,Object>paramsMap=newHashMap<>();paramsMap.put(“startDate”,”2021-04-01″);paramsMap.put(“endDate”,”2021-04-13″);//将请求参数

    2022年10月4日
    1
  • linux查看mysql用户权限_教您如何查看MySQL用户权限

    linux查看mysql用户权限_教您如何查看MySQL用户权限教您如何查看MySQL用户权限如果需要查看MySQL用户权限,应该如何实现呢?下面就为您介绍查看MySQL用户权限的方法,并对授予MySQL用户权限的语句进行介绍,供您参考。查看MySQL用户权限:showgrantsfor你的用户比如:showgrantsforroot@’localhost’;Grant用法GRANTUSAGEON*.*TO’discuz’@’local…

    2022年6月18日
    151
  • 1.巴特沃斯模拟滤波器(低通,高通,带通,带阻)设计-MATLAB实现

    1.巴特沃斯模拟滤波器(低通,高通,带通,带阻)设计-MATLAB实现1.基础知识介绍我们首先明确一个知识(这个非常重要):某正弦信号,频率为50Hz这意味着信号的模拟频率fff=50(Hz),注意它的单位是Hz信号的表达式为y=sin(2πft)=sin(2π∗50t)=sin(100πt)y=sin(2\pift)=sin(2\pi*50t)=sin(100\pit)y=sin(2πft)=sin(2π∗50t)=sin(100πt)由于信号也可以表示为y=sin(Ωt)y=sin(\Omegat)y=sin(Ωt)的形式,所以这里

    2022年5月16日
    767
  • Qt —— QWebEngineView加载谷歌离线地图(包含离线地图瓦片下载制作)

    Qt —— QWebEngineView加载谷歌离线地图(包含离线地图瓦片下载制作) 关注微信公众号搜索”Qt_io_”或”Qt开发者中心”了解更多关于Qt、C++开发知识.。笔者-jxd

    2022年9月20日
    3
  • 软件测试流程详解「建议收藏」

    软件测试流程详解「建议收藏」1.软件测试的定义:使用人工或自动手段,来运行或测试某个系统的过程。其目的在于检验它是否满足规定的需求或弄清预期结果与实际结果之间的差别。百度百科定义:软件测试(英语:SoftwareTesting),描述一种用来促进鉴定软件的正确性、完整性、安全性和质量的过程。换句话说,软件测试是一种实际输出与预期输出间的审核或者比较过程。软件测试的经典定义是:在规定的条件下对程序进行操作,以发现程序错…

    2022年6月7日
    24
  • adventureworksdw2008r2_数据库表例子

    adventureworksdw2008r2_数据库表例子从观望

    2025年10月30日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号