pytorch – ohem 代码实现

pytorch – ohem 代码实现如果考虑类别和坐标两种情况:importtorchimporttorch.nn.functionalasFimporttorch.nnasnnsmooth_l1_sigma=1.0smooth_l1_loss=nn.SmoothL1Loss(reduction=’none’)#reduce=Falsedefohem_loss(batch_size,…

大家好,又见面了,我是你们的朋友全栈君。

如果考虑类别和坐标两种情况:

import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')    # reduce=False


def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):   
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training      
     loc_pred (FloatTensor): [R, 4], location of positive rois      
     loc_target (FloatTensor): [R, 4], location of positive rois   
     pos_mask (FloatTensor): [R], binary mask for sampled positive rois   
     cls_pred (FloatTensor): [R, C]     
     cls_target (LongTensor): [R]  
     Returns:    
           cls_loss, loc_loss (FloatTensor)
    """

    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    # 这里先暂存下正常的分类loss和回归loss
    print(ohem_cls_loss.shape, ohem_loc_loss.shape)
    loss = ohem_cls_loss + ohem_loc_loss
    # 然后对分类和回归loss求和
    
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)   
    # 再对loss进行降序排列
    
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)    
    # 得到需要保留的loss数量
    
    if keep_num < sorted_ohem_loss.size()[0]:    
        # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
    
        keep_idx_cuda = idx[:keep_num]        # 保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]      
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]        # 分类和回归保留相同的数目
        
    cls_loss = ohem_cls_loss.sum() / keep_num   
    loc_loss = ohem_loc_loss.sum() / keep_num    # 然后分别对分类和回归loss求均值
    return cls_loss, loc_loss


if __name__ == '__main__':
    batch_size = 4
    C = 6
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    cls_pred = torch.randn(8, C)
    cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
    cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
    print(cls_loss, '--', loc_loss)

如果只考虑坐标框的话,对以上代码略微调整如下:

import torch
import torch.nn.functional as F
import torch.nn as nn

smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')  # reduce=False


def ohem_loss(batch_size, loc_pred, loc_target):
    """    Arguments:
     batch_size (int): number of sampled rois for bbox head training
     loc_pred (FloatTensor): [R, 4], location of positive rois
     loc_target (FloatTensor): [R, 4], location of positive rois
     Returns:
           cls_loss, loc_loss (FloatTensor)
    """
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
    loss = ohem_loc_loss  # 对上面代码进行改动,不做简化了,感兴趣的自行替换

    # 再对loss进行降序排列
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)

    # 得到需要保留的loss数量
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)

    # 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留(自己可以改动为保留的比例)
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]  # 保留到需要keep的数目
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]  # 回归保留相同的数目

    loc_loss = ohem_loc_loss.sum() / keep_num  # 然后对回归loss求均值
    return loc_loss


if __name__ == '__main__':
    batch_size = 4
    loc_pred = torch.randn(8, 4)
    loc_target = torch.randn(8, 4)
    loc_loss = ohem_loss(batch_size,loc_pred, loc_target)
    print(loc_loss)

以上代码,新建Python文件,右键运行即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140952.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • vector>初始化_电脑初始化出现问题

    vector>初始化_电脑初始化出现问题1、默认初始化,vector为空,size为0,未开辟空间,可通过push_back()添加元素。vector&lt;int&gt;v;v.push_back(10);2、默认初始化,指定vector大小,元素初始值默认为0,元素增多时,同样可以通过push_back()来改变vector大小以增加元素。vector&lt;int&gt;v(5)3、指定初始化元素值为2…

    2022年9月18日
    2
  • 无锁编程技术及实现「建议收藏」

    无锁编程技术及实现「建议收藏」1.基于锁的编程的缺点 多线程编程是多CPU系统在中应用最广泛的一种编程方式,在传统的多线程编程中,多线程之间一般用各种锁的机制来保证正确的对共享资源(share resources)进行访问和操作。在多线程编程中只要需要共享某些数据,就应当将对它的访问串行化。比如像++count(count是整型变量)这样的简单操作也得加锁,因为即便是增量操作这样的操作,,实际上也是分三步进行的:读、改、写(回…

    2022年6月10日
    41
  • littlevgl移植_嵌入式ubuntu系统

    littlevgl移植_嵌入式ubuntu系统总述Littlevgl相比较于安卓、QT,占用资源少、使用简单,所以在linux系统下使用Littlevgl优势也比较明显。移植准备工作源码:lvgl:https://github.com/littlevgl/lvgl驱动:lv_drivers:https://github.com/littlevgl/lv_drivers例子:lv_examples:https://github.com/littlevgl/lv_examples下载慢可以将上面链接先导入到码云上再下载。配置工作源码

    2022年9月2日
    3
  • APK签名流程介绍[通俗易懂]

    APK签名流程介绍[通俗易懂]实际上,现在Android开发IDE自带签名功能,但是有时我们还是可能遇到自己签名apk的场景的,比如你有一个未签名的apk,但是你要adbinstall到device上,这时我们在adbinstall之前就必须对该apk进行签名处理才能install成功,这篇文章就简单的介绍下apk签名流程吧。1、生成签名证书签名需要签名证书,签名证书类型实际上是有很多的,如jks、keysto…

    2022年6月13日
    41
  • BaseAdapter导致notifyDataSetChanged()无效的四个原因及处理方法

    BaseAdapter导致notifyDataSetChanged()无效的四个原因及处理方法前一段时间在做一个项目的时候遇到了一个关于BaseAdapter的notifyDataSetChanged()方法无效问题,当时在网上搜了一个解决方法,今天又遇到了一个类似的问题,我在这里做个记录,防止以后再次发生,或者其他朋友再次遇到。一、ScrollView中嵌套ListView或GridView原因:两个的滚动监听冲突解决方法:重写ListView或GridViewpackagecom.m

    2022年6月18日
    24
  • 数据库:SQLServer 实现行转列、列转行用法笔记

    数据库:SQLServer 实现行转列、列转行用法笔记

    2020年11月14日
    327

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号