OHEM介绍

OHEM介绍在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个目标,也就是说绝大部分框是没有目标的,为了减少计算就需要进行sample,一般来说fasterrcnn的sample机制是算框和label的IOU,大于0.7认为是正样本,小于0.3是负样本。但是单纯的random_sample选出来的框不一定是最容易错的框。那么ohem就是较好的一种正负样本策略

大家好,又见面了,我是你们的朋友全栈君。

目标检测之OHEM介绍

论文地址:https://arxiv.org/pdf/1604.03540.pdf

在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个目标,也就是说绝大部分框是没有目标的,为了减少计算就需要进行sample,一般来说fasterrcnn的sample机制是算框和label的IOU,大于0.7认为是正样本,小于0.3是负样本。但是单纯的random_sample选出来的框不一定是最容易错的框。那么ohem就是这样的一种正负样本策略,通过根据框的loss得到最容易错的框。可以理解为错题集,我们只会把最容易错的题放到错题集。

首先是 negative,即负样本,其次是 hard,说明是困难样本,也可以说是容易将负样本看成正样本的那些样本,例如 RPN框里没有物体,全是背景,这时候分类器很容易正确分类成背景,这个就叫 easy negative;如果 框里有二分之一个物体,标签仍是负样本,这时候分类器就容易把他看成正样本,这时候就是 had negative。hard negative mining 就是多找一些 hard negative 加入负样本集,进行训练。
接下来我们来看看mmdection的ohem实现:

class OHEMSampler(BaseSampler):
    r"""Online Hard Example Mining Sampler described in `Training Region-based
    Object Detectors with Online Hard Example Mining
    <https://arxiv.org/abs/1604.03540>`_.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        self.context = context
        if not hasattr(self.context, 'num_stages'):
            self.bbox_head = self.context.bbox_head
        else:
            self.bbox_head = self.context.bbox_head[self.context.current_stage]

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            if not hasattr(self.context, 'num_stages'):
                bbox_results = self.context._bbox_forward(feats, rois)
            else:
                bbox_results = self.context._bbox_forward(
                    self.context.current_stage, feats, rois)
            cls_score = bbox_results['cls_score']
            loss = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=rois,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')['loss_cls']
            _, topk_loss_inds = loss.topk(num_expected)
        return inds[topk_loss_inds]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample positive boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected positive samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices  of positive samples
        """
        # Sample some hard positive samples
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
                                    assign_result.labels[pos_inds], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample negative boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected negative samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices  of negative samples
        """
        # Sample some hard negative samples
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            neg_labels = assign_result.labels.new_empty(
                neg_inds.size(0)).fill_(self.bbox_head.num_classes)
            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
                                    neg_labels, feats)

上面代码就是整个ohem的sample过程,整个ohem分为三个函数分别是hard_mining,_sample_pos,_sample_neg,_sample_pos和_sample_neg是获得对应的困难正样本/困难负样本,由hard_mining完成整个sample过程:根据输入的box_list得到对应的bbox_loss的list取最大的256/512个,由于这一批box的loss最大,就可以认为是最难区分的box,这一批bbox就是所谓的
困难正样本/困难负样本。

至此ohem阶段完成,后面就是对候选框的分类和回归,因为ohem阶段得到了容易分错的样本框,所以在后续训练阶段模型会对这些容易分错的框重点关注,有利于困难样本的检测,提升了模型的效果。

实际上提升还是很明显的:
在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139162.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 纸张与什么是使用喷墨打印机所需的消耗品(打印机打印出来的纸黑乎乎的)

    【PConline杂谈】一直潜心研究打印机的小编,由于长时间周旋于各种打印机,甚是无聊。因此近日研究了点特别的东西。关于打印机耗材方面,一般都是硒鼓、墨盒等。对于打印机要用量最大的纸张耗材,却鲜有人关注。于是,小编就趁着元旦假期去恶补了下相关知识。对于常用的纸张耗材,给人的感觉却是熟悉而又陌生的。因此,因此这篇文章就谈一谈关于纸的知识。纸,四大古代发明之一,在人们的日常生活中发挥着及其重要的作用…

    2022年4月11日
    106
  • IoC控制反转是什么意思?[通俗易懂]

    IoC控制反转是什么意思?[通俗易懂]最近由于日本项目的需要,开始学习Spring框架的东西。虽然框架被日方公司进行了一定的修改,但Spring大体原理是不变的。Spring最大的特点,相信大家在网上看了许多,都知道是控制反转(IOC),或者叫依赖注入(DI),那么究竟什么是控制反转,什么是依赖注入呢?IOC(inversionofcontrol)控制反转模式;控制反转是将组件间的依赖关系从程序内部提到外部来管理; DI(depe…

    2022年6月29日
    24
  • python运行pyc文件_Python pyc文件[通俗易懂]

    python运行pyc文件_Python pyc文件[通俗易懂]什么是pyc文件pyc是由py文件经过编译后二进制文件,py文件变成pyc文件后,加载的速度有所提高,而且pyc是一种跨平台的字节码,是由python的虚拟机来执行的。pyc的内容,是跟python的版本相关的,不同版本编译后的pyc文件是不同的,2.5编译的pyc文件,2.4版本的python是无法执行的。pyc文件也是可以反编译的,不同版本编译后的pyc文件是不同。为什么需要pyc文件…

    2022年6月16日
    35
  • RSA 加密算法原理简述

    RSA 加密算法原理简述概述本文旨在说明RSA加密算法的原理及实现,而其相关的数学部分的证明则不是本文内容。版权说明著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。作者:Q-WHai发表日期:2016年2月29日本文链接:http://blog.csdn.net/lemon_tree12138/article/details/50696926来源:CSDN…

    2022年6月12日
    25
  • apk反编译后smali文件_apk反编译失败

    apk反编译后smali文件_apk反编译失败APK反编译之一:基础知识原文作者:lpohvbe| http://blog.csdn.net/lpohvbe/article/details/7981386APK、Dalvik字节码和smali文件APK文件  大家都应该知道APK文件其实就是一个MIME为ZIP的压缩包,我们修改ZIP后缀名方式可以看到内部的文件结构,例如修改后缀后用RAR

    2022年9月2日
    4
  • 小议AutoEventWireup属性

    小议AutoEventWireup属性1.在web页面添加一个label和button控件ViewCode<%@PageLanguage=”C#”AutoEventWireup=”false”CodeFile=”AutoEventWireup属性.aspx.cs”Inherits=”_Default”%><!DOCTYPEhtmlPUBLIC”-//W3C//DT…

    2022年5月28日
    30

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号