pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140952.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Emacs中的文件比较与合并工具–Ediff

    Emacs中的文件比较与合并工具–Ediff

    2021年8月27日
    133
  • linux目录结构详解_linux目录的结构及含义

    linux目录结构详解_linux目录的结构及含义前言平常linux系统用的也不少,那么linux下的每个目录都是用来干什么的,小伙伴们有仔细研究过吗?让我们来了解下吧Linux系统目录结构登录系统后,在当前命令窗口下输入命令:[root@

    2022年7月31日
    6
  • aliddns ipv6_Python实现阿里云域名DDNS支持ipv4和ipv6-阿里云开发者社区

    aliddns ipv6_Python实现阿里云域名DDNS支持ipv4和ipv6-阿里云开发者社区前言然后你的IP必须是公网IP,不然解析了也没用。本文章讲怎样通过阿里云的SDK来添加修改域名解析,检查本机IP与解析的IP是否一致,不一致自动修改解析,达到动态解析的目的,主要用于家庭宽带这些动态IP的地方。安装阿里云SDK和其他第三方库pipinstallaliyun-python-sdk-core-v3pipinstallaliyun-python-sdk-domainpipins…

    2022年5月30日
    34
  • 数据库的范式(第一范式,第二范式,第三范式,BCNF范式)「建议收藏」

    在了解范式之前我们先了解下数据库中关于码的概念1.码1.超码能够唯一标识元组的某一属性或属性组,任何包含超码的超集也是超码,这里唯一标识元组可以简单的理解为根据某一个字段或几个字段的值,查询出某一行特定的数据2.候选码从超码中选出的最小的码,即其任何真子集都不能满足条件。即属性不可再删减。3.主码从候选码中选出一个作为主码。2.范式(NF)范式:符合某一种级别的关系模式的集合,简…

    2022年4月16日
    62
  • 论文写作利器—LaTeX教程(入门篇)(更新中)

    论文写作利器—LaTeX教程(入门篇)(更新中)一、LaTeX简介结合维基百科及LaTeX官网可知:LaTeX(/ˈlɑːtɛx/,常被读作/ˈlɑːtɛk/或/ˈleɪtɛk/)是一种基于TeX的高品质排版系统,由美国计算机科学家莱斯利·兰伯特在20世纪80年代初期开发,非常适用于生成高印刷质量的科技和数学、物理文档,尤其擅长于复杂表格和数学公式的排版。LaTeX是科学文献交流和出版的事实标准。简单来说,相比于Word排版时需要设…

    2022年7月14日
    44
  • 什么是SOA架构?为什么使用SOA架构?

    什么是SOA架构?为什么使用SOA架构?SOA架构简介面向服务的架构(SOA)是一个组件模型,它将应用程序的不同功能单元(称为服务)进行拆分,通过这些服务之间定义良好的接口和契约联系起来。接口是采用中立的方式进行定义的,它应该独立于实现服务的硬件平台、操作系统和编程语言。这使得构建在各种这样的系统中的服务可以以一种统一和通用的方式进行交互**SOA具有以下五个特征**1.可重用;2.松耦合;3.明确定义的接口;…

    2022年6月24日
    29

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号