NativeScaler()与loss_scaler

NativeScaler()与loss_scaler源码:classNativeScaler:state_dict_key=”amp_scaler”def__init__(self):self._scaler=torch.cuda.amp.GradScaler()def__call__(self,loss,optimizer,clip_grad=None,clip_mode=’norm’,parameters=None,create_graph=False):

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全家桶1年46,售后保障稳定

源码:

class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            assert parameters is not None
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

Jetbrains全家桶1年46,售后保障稳定

loss_scaler 函数,它的作用本质上是 loss.backward(create_graph=create_graph) 和 optimizer.step()。

loss_scaler 继承 NativeScaler 这个类。这个类的实例在调用时需要传入 loss, optimizer, clip_grad, parameters, create_graph 等参数,在 __call__ () 函数的内部实现了 loss.backward(create_graph=create_graph) 功能和 optimizer.step() 功能。

例子使用:

from timm.utils import NativeScaler

loss_scaler = NativeScaler()
loss_scaler(loss_G, optimizer, parameters=model_restoration.parameters())

代码等价: 

loss_G.backward()
optimizer.step()
等价于下面的代码:
loss_scaler(loss_G, optimizer, parameters=model_restoration.parameters())
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/215508.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • php环境安装与配置_windows下php环境搭建

    php环境安装与配置_windows下php环境搭建配置时区打开php解压目录,找到php.ini-development文件,将其改名为php.ini,用记事本打开。找到(带分号);date.timezone=去掉前面的分好,修改为date.timezone=Asia/Shanghai测试:在根目录下的index.php文件夹中写入以下代码<?phpechodate(“Y:m:dH:i:s”…

    2022年9月22日
    3
  • docker dockerfile详解_进入docker容器命令

    docker dockerfile详解_进入docker容器命令前言Dockerfile是一个用来构建镜像的文本文件,文本内容包含了一条条构建镜像所需的指令和说明。Dockerfile简介Dockerfile是用来构建Docker镜像的构建文件,是由一系列

    2022年7月29日
    4
  • 粒子群优化算法matlab程序_多目标优化算法

    粒子群优化算法matlab程序_多目标优化算法1.粒子群优化算法概述2.粒子群优化算法求解     2.1连续解空间问题     2.2构成要素     2.3算法过程描述     2.4粒子速度更新公式     2.5速度更新参数分析3.粒子群优化算法小结4.MATLAB

    2022年10月11日
    4
  • CentOS 7如何配置yum源「建议收藏」

    CentOS 7如何配置yum源「建议收藏」相关说明:      本教程主要讲解配置“本地yum源”、“网络yum源”以及“ELEP源”yum简介:     1.Yum(全称为YellowdogUpdater,Modified)是一个在Fedora和RedHat以及CentOS中的Shell前端软件包管理器。        2.基于RPM包管理,能够从指定的服务器自动下载RPM包并且安装,可以自动处理依赖性关系,并且一次…

    2022年8月13日
    4
  • Python环境配置及项目建立

    Python环境配置及项目建立一、安装PythonPython比较稳定的两个版本是Python3.5和Python2.7,我用的是Python2.7,下载地址是:https://www.python.org/downloa

    2022年7月5日
    21
  • 移位寄存器专题(verilog HDL设计)

    移位寄存器专题(verilog HDL设计)目录移位寄存器简介分类4位右移位寄存器工作原理1、16位右移位寄存器2、16位左移寄存器3、串行输入并行输出寄存器4、并行输入串行输出移位寄存器移位寄存器简介移位寄存器内的数据可以在移位脉冲(时钟信号)的作用下依次左移或右移。移位寄存器不仅可以存储数据,还可以用来实现数据的串并转换、分频,构成序列码发生器、序列码检测器,进行数值运算以及数据处理等,它也…

    2022年7月16日
    13

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号