pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140952.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 使用栈实现表达式求值

    使用栈实现表达式求值任何一个表达式都是由操作数,运算符,界限符组成的。操作数即是参加运算的数值或者变量,运算符则是加减乘除等组成,为简单起见,这里只实现加减乘除的运算,而常见的界限符则是左右括号和终止符。在运算过程中,要判断两个先后出现的运算符之间的优先顺序。为了实现算法,设置两个工作栈:用于存储运算符的栈opter,以及用于存储操作数及中间结果的栈opval。算法基本思想如下:(1)首先将操作数栈opv

    2022年6月26日
    28
  • metasploit指令_msfconsole下载

    metasploit指令_msfconsole下载在MSF里面msfconsole可以说是最流行的一个接口程序。很多人一开始碰到msfconsole的时候就害怕了。那么多复杂的命令语句需要学习,但是msfconsole真的是一个强大的接口程序。Msfconsole提供了一个一体化的集中控制台。通过msfconsole,你可以访问和使用所有的metasploit的插件,payload,利用模块,post模块等等。Msfconsole还有第三方程序的…

    2022年9月7日
    0
  • mysql的where条件后加case_recommend

    mysql的where条件后加case_recommend背景:数据库用的Oracle;报表用的是【FineReport】,之前没用过,被临时授命解决问题,所以大概了解了一下。里面应该是集成了excel插件,报表样式如下:今天在项目中遇到一个这样的场景:A为汇总页面,显示的是按医院分组统计出来的一些数据,效果如下图图中每一列都能下钻到另一个页面,医院名称和起始时间都作为参数传送。前期因为某一些需求,有一家医院出现了两个不同的名…

    2022年9月4日
    3
  • Python str join方法:拼接字符串「建议收藏」

    Python str join方法:拼接字符串「建议收藏」Python字符串方法join()介绍、使用和注意事项。

    2022年4月29日
    46
  • POJ2186 Popular Cows 【强连通分量】+【Kosaraju】+【Tarjan】+【Garbow】

    POJ2186 Popular Cows 【强连通分量】+【Kosaraju】+【Tarjan】+【Garbow】

    2022年1月29日
    38
  • 数据滤波算法集合「建议收藏」

    数据滤波算法集合「建议收藏」由于要进行数据处理,就利用网络资源总结各种滤波方法以便日后查阅。一、限幅滤波法实现步骤:根据经验法选择最大偏差值E。|value_now-value_before|&amp;amp;amp;amp;amp;amp;amp;amp;amp;amp;amp;lt;=E,value_now有效,否则其无效且将其舍弃,最后令value_now=value_before。实现程序:#defineE10//value取值范围为90~110intv

    2022年5月3日
    103

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号