NativeScaler()与loss_scaler

NativeScaler()与loss_scaler源码:classNativeScaler:state_dict_key=”amp_scaler”def__init__(self):self._scaler=torch.cuda.amp.GradScaler()def__call__(self,loss,optimizer,clip_grad=None,clip_mode=’norm’,parameters=None,create_graph=False):

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全家桶1年46,售后保障稳定

源码:

class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            assert parameters is not None
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

Jetbrains全家桶1年46,售后保障稳定

loss_scaler 函数,它的作用本质上是 loss.backward(create_graph=create_graph) 和 optimizer.step()。

loss_scaler 继承 NativeScaler 这个类。这个类的实例在调用时需要传入 loss, optimizer, clip_grad, parameters, create_graph 等参数,在 __call__ () 函数的内部实现了 loss.backward(create_graph=create_graph) 功能和 optimizer.step() 功能。

例子使用:

from timm.utils import NativeScaler

loss_scaler = NativeScaler()
loss_scaler(loss_G, optimizer, parameters=model_restoration.parameters())

代码等价: 

loss_G.backward()
optimizer.step()
等价于下面的代码:
loss_scaler(loss_G, optimizer, parameters=model_restoration.parameters())
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/215508.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Fedora 12 环境搭建[通俗易懂]

    Fedora 12 环境搭建[通俗易懂]又来折腾发行版了。这一回是Fedora12,搞的挺艰难的下载了Fedora-12-i386-DVD.iso,无论使用ultraiso还是dd都无法安装。后来下载了一个ImageWriter.exe(o

    2022年7月3日
    29
  • CLLocation定位

    CLLocation定位importUIKitimportCoreLocationimportAlamofiretypealiasLocationClosure=((_sheng:String,_shi:String,_qu:String)->Void)classCLLocationTool:NSObject{publicstaticlet`default`=CLLocationTool.init()///定…

    2022年7月26日
    3
  • Redis-03Redis数据结构–全局命令及字符串string

    Redis-03Redis数据结构–全局命令及字符串string在了解具体的数据结构类型之前,我们有必要了解下Redis提供的操作key的全局命令、数据结构和内部编码、单线程命令处理机制,都有助于加深对Redis的理解。

    2022年10月3日
    2
  • gb50174-2017电子信息系统机房设计规范发布时间_计算机机房设计标准GB50174

    gb50174-2017电子信息系统机房设计规范发布时间_计算机机房设计标准GB50174中国工程建设标准化协会信息通信专业委员会建标信通字[2009]03号GB50174《电子信息系统机房设计规范》贯标培训通…

    2022年9月28日
    7
  • 笔试面试算法经典–最长回文子串

    笔试面试算法经典–最长回文子串回文的定义正读和反读都相同的字符序列为“回文”,如“abba”、“abccba”是“回文”,“abcde”和“ababab”则不是“回文”。字符串的最长回文子串,是指一个字符串中包含的最长的回文子串。例如“1212134”的最长回文子串是“12121”。下面给出了三种求最长子串的方法。解法1(中心扩展法)时间复杂度O(n^2),空间复杂度为O(1)。中心扩展法的思路是,遍历到数组的某一个元素时,以这

    2022年6月9日
    78
  • dstat使用[通俗易懂]

    dstat使用[通俗易懂]1、安装方法一:yum#yuminstall-ydstat方法二:rpm官网下载地址:http://dag.wieers.com/rpm/packages/dstat #wget http://dag.wieers.com/rpm/packages/dstat/dstat-0.6.7-1.rh7.rf.noarch.rpm#rp

    2022年6月15日
    44

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号