在线难例挖掘(OHEM)[通俗易懂]

在线难例挖掘(OHEM)[通俗易懂]OHEM(onlinehardexampleminiing)详细解读一下OHEM的实现代码:defohem_loss(batch_size,cls_pred,cls_target,loc_pred,loc_target,smooth_l1_sigma=1.0):”””Arguments:batch_size(int):…

大家好,又见面了,我是你们的朋友全栈君。

OHEM(online hard example miniing)

详细解读一下OHEM的实现代码:

def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of sampled rois for bbox head training
        loc_pred (FloatTensor): [R, 4], location of positive rois
        loc_target (FloatTensor): [R, 4], location of positive rois
        pos_mask (FloatTensor): [R], binary mask for sampled positive rois
        cls_pred (FloatTensor): [R, C]
        cls_target (LongTensor): [R]

    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    #这里先暂存下正常的分类loss和回归loss
    loss = ohem_cls_loss + ohem_loc_loss
    #然后对分类和回归loss求和

  
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    #再对loss进行降序排列
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)
    #得到需要保留的loss数量
    if keep_num < sorted_ohem_loss.size()[0]:
    #这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
        keep_idx_cuda = idx[:keep_num]
        #保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
        #分类和回归保留相同的数目
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    #然后分别对分类和回归loss求均值
    return cls_loss, loc_loss

为什么要叫在线难例最小化呢?

因为在深度学习提出这个方法的人,想和传统方法区分开。难例挖掘,机器学习学习中尤其是在svm中早就已经使用,又称为bootstrapping。

传统的难例挖掘流程:首先是通过训练集训练网络,训练完成,然后固定网络,寻找新的样本,加入到训练集中。很显然这将耗费很长的时间。

因此作者提出的是在线难例挖掘。

 

具体怎么实现的呢?

在线难例挖掘(OHEM)[通俗易懂]

通常是搬出这张图,说实话这张图有点啰嗦!

按我的理解,OHEM的操作就是舍弃了faster RCNN中的正负样本(ROI)比例为1:3,它通过每个ROI的loss值,对所有roi的loss排序,取B/N数量的roi组成mini batch。注意:对于指向同一个目标的rois,通过NMS,取loss最大的roi,其他都删除。

也就是通过loss提高难样本的比例,让网络花更多精力去学习难样本。

 

我觉得它和focal loss思路本质是一样的,focal loss把loss作用在类别上,二目标检测OHEM把loss 作用在ROI上。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139244.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月30日 上午8:00
下一篇 2022年5月30日 上午8:16


相关推荐

  • python 安装第三方包-安装失败(pycharm/ anaconda navigator)

    python 安装第三方包-安装失败(pycharm/ anaconda navigator)安装/卸载第三包:方法一:pipinstall包名;pipuninstall包名。方法二:pipinstall下载路径\包名.whl(需要先下载第三包:地址:https://www.lfd.uci.edu/~gohlke/pythonlibs/,找到所需的包并下载保存),如下例安装gensim包所示:方法三:若在pycharm编辑中,则在菜单setting/project…

    2022年8月25日
    6
  • React 生命周期详细解析及新旧对比

    React 生命周期详细解析及新旧对比React 生命周期新旧对比

    2026年3月19日
    2
  • 空间变换是什么_信号与系统状态转移矩阵

    空间变换是什么_信号与系统状态转移矩阵出自论文SpatialTransformerNetworksInsight:文章提出的STN的作用类似于传统的矫正的作用。比如人脸识别中,需要先对检测的图片进行关键点检测,然后使用关键点来进行对齐操作。但是这样的一个过程是需要额外进行处理的。但是有了STN后,检测完的人脸,直接就可以做对齐操作。关键的一点就是这个矫正过程是可以进行梯度传导的。想象一下,人脸检测完了,直接使用R

    2022年10月19日
    5
  • java编译原理

    java编译原理4.Java编译原理1.javac是什么?(1)javac是一种编译器,能够将一种语言规范转换成另一种用语言规范,通常编译器是将便于人们理解的语言规范成机器容易理解的语言规范。(2)javac的任务就是将java源代码语言转换成jvm能够识别的语言,然后jvm将jvm语言再转化成当前机器能够识别的语言(这样使得对开发者屏蔽与机器相关的细节,并且使得语言的执行与平台无关)2.javac编译器的基本结…

    2022年5月9日
    39
  • Tab功能的使用

    Tab功能的使用

    2026年3月16日
    2
  • 第一章 热传导方程

    第一章 热传导方程第一章热传导方程 目录如下 1 推导一维杆的热传导方程 从微分及积分角度分别进行了推导 2 初值和边界条件 初值是与时间相关 边值与空间相关 3 二维及三维热传导方程推导 从积分角度推导 得到泊松方程和拉普拉斯方程 4 拉普拉斯算子的各种形式 在直角坐标系 柱坐标系和球坐标系下推导拉普拉斯算子形式 nbsp 偏微分方程 PDE 就是指含有偏导数的数学方程 本书从

    2026年3月20日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号