pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140952.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 死亡Error:OSError: [Errno 12] Cannot allocate memory

    死亡Error:OSError: [Errno 12] Cannot allocate memory死亡Error:OSError:[Errno12]Cannotallocatememory调试背景:使用的是github上https://github.com/arunmallya/packnet这里的代码。调试的时候,出现Error,如下:main()File”main.py”,line378,inmainmanager.prune()Fi…

    2022年6月24日
    120
  • Sublime Text3 如何安装、删除及更新插件

    Sublime Text3 如何安装、删除及更新插件1、打开SublimeText3,按Ctrl+`(和qq输入法快捷切换冲突,可以修改qq的输入法切换热键)2、复制粘黏以下代码添加至命令行,然后回车(功能:安装插件的工具,有了它,以后安装其他插件更方便)importurllib.request,os;pf=’PackageControl.sublime-package’;ipp=sublime.inst…

    2022年7月11日
    22
  • vue 集成高德地图进行批量标注和信息窗体展示

    vue 集成高德地图进行批量标注和信息窗体展示                                   vue集成高德地图进行批量标注和信息窗体展示         高德地图进行地理位置的标注和信息窗体展示是我们很常用的一个功能,其实高德api里面已经清楚的说明怎么用了,但是自己总结一下记录在自己的笔记里,也是有些许好处的。高德api样列展示地址是:https://lbs.amap.com/api/javascript…

    2022年5月21日
    66
  • mysql语句和sql语句的区别_oracle和sqlserver的语法区别

    mysql语句和sql语句的区别_oracle和sqlserver的语法区别sql和mysql语法的区别有:mysql支持enum和set类型,sql不支持,mysql需要为表指定存储类型,mysqlL中text字段类型不允许有默认值,sql允许有等等方面都存在差异MySQL与SQLServer的语法区别1、MySQL支持enum,和set类型,SQLServer不支持2、MySQL不支持nchar,nvarchar,ntext类型3、MySQL的递增语句是AUTO_I…

    2022年10月2日
    2
  • 小虾的sql server 2000 成长之路

    小虾的sql server 2000 成长之路

    2021年7月29日
    23
  • redis 过期删除策略(redis过期机制)

    过期删除策略redis可以对key的通用设置中,可以设置key的过期时间及ttl如果单纯的再client中进行命令测试的话,会发现了当时间到时间后再去获取该key会显示nil那么一个key过期了,那么它实际是在什么时候删除的呢?当然这个删除也不是简单的到期了就直接被删除了redis中对于过期键的过期删除策略定时删除惰性删除定期删除定时删除它会在设置键的过期时间的同时,创建一个定时器,当键到了过期时间,定时器会立即对键进行删除。这个策略能够保证过期键的尽快删除,快速释放内存空间

    2022年4月10日
    98

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号