OHEM代码梳理[通俗易懂]

OHEM代码梳理[通俗易懂]传送门:相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结代码地址:OHEM1.前言有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得…

大家好,又见面了,我是你们的朋友全栈君。

  • 传送门
  1. 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
  2. 代码地址:OHEM

1. 前言

有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。

layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}

2. OHEM代码简单梳理

2.1 OHEMDataLayer

class OHEMDataLayer(caffe.Layer):
    """Online Hard-example Mining Layer."""
    def setup(self, bottom, top):
        """Setup the OHEMDataLayer."""

        # parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)

        self._num_classes = layer_params['num_classes']	 # 获取分类数目

        self._name_to_bottom_map = { 
     # 将bottom的blob名称与index使用dict关联
            'cls_prob_readonly': 0,
            'bbox_pred_readonly': 1,
            'rois': 2,
            'labels': 3}

        if cfg.TRAIN.BBOX_REG:  # 有边界框回归
            self._name_to_bottom_map['bbox_targets'] = 4
            self._name_to_bottom_map['bbox_loss_weights'] = 5

        self._name_to_top_map = { 
   }  # 同理top的blob名称也要与index关联起来
        ……
	
	# 前向传播函数
    def forward(self, bottom, top):
        """Compute loss, select RoIs using OHEM. Use RoIs to get blobs and copy them into this layer's top blob vector."""

        cls_prob = bottom[0].data  # 获取对应bottom的数据
        bbox_pred = bottom[1].data
        rois = bottom[2].data
        labels = bottom[3].data
        if cfg.TRAIN.BBOX_REG:
            bbox_target = bottom[4].data
            bbox_inside_weights = bottom[5].data
            bbox_outside_weights = bottom[6].data
        else:
            bbox_target = None
            bbox_inside_weights = None
            bbox_outside_weights = None

        flt_min = np.finfo(float).eps
        # 计算分类的损失
        loss = [ -1 * np.log(max(x, flt_min)) \
            for x in [cls_prob[i,label] for i, label in enumerate(labels)]]

        # 计算边界框回归的损失,并且将两个损失加起来
        if cfg.TRAIN.BBOX_REG:
            # bounding-box regression loss
            # d := w * (b0 - b1)
            # smoothL1(x) = 0.5 * x^2 if |x| < 1
            # |x| - 0.5 otherwise
            def smoothL1(x):
                if abs(x) < 1:
                    return 0.5 * x * x
                else:
                    return abs(x) - 0.5

            bbox_loss = np.zeros(labels.shape[0])  # 边界框损失
            for i in np.where(labels > 0 )[0]:
                indices = np.where(bbox_inside_weights[i,:] != 0)[0]
                bbox_loss[i] = sum(bbox_outside_weights[i,indices] * [smoothL1(x) \
                    for x in bbox_inside_weights[i,indices] * (bbox_pred[i,indices] - bbox_target[i,indices])])
            loss += bbox_loss  # 两者损失相加

        # 筛选出损失比较大的返回
        blobs = get_ohem_minibatch(loss, rois, labels, bbox_target, \
            bbox_inside_weights, bbox_outside_weights)

		# 给top blob赋值
        for blob_name, blob in blobs.iteritems():
            top_ind = self._name_to_top_map[blob_name]
            # Reshape net's input blobs
            top[top_ind].reshape(*(blob.shape))
            # Copy data into net's input blobs
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)

2.2 get_ohem_minibatch

# 获取OHEM训练的batch
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM."""
    loss = np.array(loss)
	
	# 使用NMS过滤检测框
    if cfg.TRAIN.OHEM_USE_NMS:	# NMS thresh=0.7
        # Do NMS using loss for de-dup and diversity
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
        source_img_ids = [roi[0] for roi in rois] # 0或1(背景与前景)
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label, \
                                    source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                boxes = np.concatenate((rois[sel_indx, 1:],
                        loss[sel_indx][:,np.newaxis]), axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])

        hard_keep_inds = select_hard_examples(loss[keep_inds])  # 按照损失排序选择样本
        hard_inds = np.array(keep_inds)[hard_keep_inds]  # 最后保留下来的困难样本索引
    else:
        hard_inds = select_hard_examples(loss)

    blobs = { 
   'rois_hard': rois[hard_inds, :].copy(),
             'labels_hard': labels[hard_inds].copy()}
    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()

    return blobs

def select_hard_examples(loss):
    """Select hard rois."""
    # Sort and select top hard examples.
    sorted_indices = np.argsort(loss)[::-1]
    hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
    # (explore more ways of selecting examples in this function; e.g., sampling)

    return hard_keep_inds
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139136.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 3分钟告诉你如何成为一名黑客?|零基础到黑客入门指南,你只需要掌握这五点能力

    3分钟告诉你如何成为一名黑客?|零基础到黑客入门指南,你只需要掌握这五点能力三分钟带各位揭秘黑客究竟是什么,以及想要成为黑客都需要具备哪些能力?

    2022年6月4日
    47
  • CentOS下Apache配置虚拟主机[通俗易懂]

    CentOS下Apache配置虚拟主机[通俗易懂]有时候我们往往一个服务器会运行多个应用,此时就需要给每个应用创建虚拟主机了,这里我创建三个文件夹,分别运行三个页面:

    2022年9月15日
    0
  • JAVA XML转对象 对象转XML

    JAVA XML转对象 对象转XML在网上看了许多XML跟Obj互相转换的Demo,但是都很复杂,现在推荐一个极度简单好理解的XML和Obj互转的例子:JacksonXML,只需要简单的几个注解就能完成XML和Obj的相互转换假设有如下xml报文:<?xmlversion=”1.0″encoding=”utf-8″?><msgbody><StringList>&…

    2022年7月21日
    11
  • pycharm 2022激活码[最新免费获取]「建议收藏」

    (pycharm 2022激活码)这是一篇idea技术相关文章,由全栈君为大家提供,主要知识点是关于2021JetBrains全家桶永久激活码的内容IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.html1STL5S9V8F-eyJsa…

    2022年3月27日
    279
  • 【JAVA】Java学习路线图「建议收藏」

    【JAVA】Java学习路线图「建议收藏」怎么学习Java,这是很多新手经常会问我的问题,现在我简单描述下一个Java初学者到就业要学到的一些东西:    首先要明白Java体系设计到得三个方面:J2SE,J2EE,J2ME(KJAVA)。J2SE,Java2PlatformStandardEdition,我们经常说到的JDK,就主要指的这个,它是三者的基础,属于桌面级应用开发,这部分如果学得好很容易拓展J2EE和J2ME。

    2022年5月15日
    31
  • kali如何更换源(怎样换一个kali源)

    KaliLinux的换源和更新1.修改源文件(需要用root权限)[plain]viewplaincopyvim /etc/apt/sources.list  2.这里修改两个我认为还好的源,因为每个地方不同,选择源的时候建议使用一些常用的吧。比如:阿里云源,中科大源之类的官方源更新的速度太慢了,所以我注释掉了,只使用两

    2022年4月12日
    159

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号