pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140952.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • python长度单位换算表_长度单位换算表大全

    python长度单位换算表_长度单位换算表大全长度单位换算表大全我国传统的长度单位有里、丈、尺、寸等。1里=150丈=500米。2里=1公里(1000米)1丈=10尺,1尺=10寸。1丈=3.33米,1尺=3.33分米,1寸=3.33厘米。国际单位制中,长度的标准单位是“米”,用符号“m”表示。1960年第十一届国际计量大会:“米的长度等于氪-86原子的2P10和5d1能级之间跃迁的辐射在真空中波长的1650763.73倍”。其他的长度单位还…

    2022年7月11日
    39
  • 固态硬盘有哪些协议知识点?「建议收藏」

    固态硬盘有哪些协议知识点?「建议收藏」固态硬盘的知识点固态硬盘的协议:同品牌,同型号,不同容量速度差距删除数据是真的将数据删除了吗?固态硬盘的协议:硬盘是属于NVM:Non-volatilememory非易失性存储器件。NVM的种类​接口总线协议:​我们固态硬盘一般有两种接口的固态,一种是SATA接口,一种是M.2接口的固态。SATA固态硬盘接口​M.2接口的固态:,这个有两种的接口,一种是2个金手指…

    2026年1月31日
    4
  • 此工作站和主域间的信任关系失败 原因及解决办法[通俗易懂]

    此工作站和主域间的信任关系失败 原因及解决办法[通俗易懂]原因:域控服务器没有客户端的主机名(可能删除了,或重装系统后没添加到域控)处理:在域控上确认客户端主机名是否被禁用,如已禁用,启动即可。转载于:https://www.cnblogs.com/sjdn/p/5923669.html…

    2022年10月19日
    5
  • 教你如何选择以太坊ETH挖矿教程及挖矿分配模式

    教你如何选择以太坊ETH挖矿教程及挖矿分配模式现在,国内国外的矿池越来越多,挖矿难度也越来越大,对于矿工来说,又好又稳定的收益保障一直是追求的目标。以太坊ETH,全球第二大加密货币,对以太坊的追求者都有一种浓重的信仰与情怀,特别是最近消息称以太坊2.0即将来临,更加的激动人心。但最终的目的也是要获得收益,影响收益最大的因素就是矿池收益分配模式,现在矿池收益分配模式有:PPS、PPLNS、PPS+、FPPS等。那么,对于挖以太坊ETH来说…

    2022年6月7日
    66
  • java学的什么软件_java初学者用什么软件[通俗易懂]

    java学的什么软件_java初学者用什么软件[通俗易懂]Java初学者可以使用MyEclipse或eclipse以及记事本。随着学习的深入,相信你会逐渐明白,你会从中找到最合适的开发工具。java初学者使用什么软件Java初学者可以使用MyEclipse、eclipse或记事本。1对于初学者,不建议使用ide开发工具,如eclipse、MyEclipse、intellijidea和netbean。但是,您也可以使用这些。原因不推荐,不方便您了解java…

    2022年7月8日
    26
  • pytorch visdom安装启动问题

    pytorch visdom安装启动问题visdom经过pip安装之后,启动时一直提醒:Checkingforscripts.Downloadingscripts,thismaytakealittlewhile然后即使挂了vpn也下载不下来。。。。网上搜了一堆教程,比较杂乱,记录以下自己简单粗暴的解决方案:C:\Users\zj1996\Anaconda3\envs\pytorch\Lib\site-p…

    2022年6月29日
    47

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号