OHEM代码梳理[通俗易懂]

OHEM代码梳理[通俗易懂]传送门:相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结代码地址:OHEM1.前言有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得…

大家好,又见面了,我是你们的朋友全栈君。

  • 传送门
  1. 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
  2. 代码地址:OHEM

1. 前言

有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。

layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}

2. OHEM代码简单梳理

2.1 OHEMDataLayer

class OHEMDataLayer(caffe.Layer):
    """Online Hard-example Mining Layer."""
    def setup(self, bottom, top):
        """Setup the OHEMDataLayer."""

        # parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)

        self._num_classes = layer_params['num_classes']	 # 获取分类数目

        self._name_to_bottom_map = { 
     # 将bottom的blob名称与index使用dict关联
            'cls_prob_readonly': 0,
            'bbox_pred_readonly': 1,
            'rois': 2,
            'labels': 3}

        if cfg.TRAIN.BBOX_REG:  # 有边界框回归
            self._name_to_bottom_map['bbox_targets'] = 4
            self._name_to_bottom_map['bbox_loss_weights'] = 5

        self._name_to_top_map = { 
   }  # 同理top的blob名称也要与index关联起来
        ……
	
	# 前向传播函数
    def forward(self, bottom, top):
        """Compute loss, select RoIs using OHEM. Use RoIs to get blobs and copy them into this layer's top blob vector."""

        cls_prob = bottom[0].data  # 获取对应bottom的数据
        bbox_pred = bottom[1].data
        rois = bottom[2].data
        labels = bottom[3].data
        if cfg.TRAIN.BBOX_REG:
            bbox_target = bottom[4].data
            bbox_inside_weights = bottom[5].data
            bbox_outside_weights = bottom[6].data
        else:
            bbox_target = None
            bbox_inside_weights = None
            bbox_outside_weights = None

        flt_min = np.finfo(float).eps
        # 计算分类的损失
        loss = [ -1 * np.log(max(x, flt_min)) \
            for x in [cls_prob[i,label] for i, label in enumerate(labels)]]

        # 计算边界框回归的损失,并且将两个损失加起来
        if cfg.TRAIN.BBOX_REG:
            # bounding-box regression loss
            # d := w * (b0 - b1)
            # smoothL1(x) = 0.5 * x^2 if |x| < 1
            # |x| - 0.5 otherwise
            def smoothL1(x):
                if abs(x) < 1:
                    return 0.5 * x * x
                else:
                    return abs(x) - 0.5

            bbox_loss = np.zeros(labels.shape[0])  # 边界框损失
            for i in np.where(labels > 0 )[0]:
                indices = np.where(bbox_inside_weights[i,:] != 0)[0]
                bbox_loss[i] = sum(bbox_outside_weights[i,indices] * [smoothL1(x) \
                    for x in bbox_inside_weights[i,indices] * (bbox_pred[i,indices] - bbox_target[i,indices])])
            loss += bbox_loss  # 两者损失相加

        # 筛选出损失比较大的返回
        blobs = get_ohem_minibatch(loss, rois, labels, bbox_target, \
            bbox_inside_weights, bbox_outside_weights)

		# 给top blob赋值
        for blob_name, blob in blobs.iteritems():
            top_ind = self._name_to_top_map[blob_name]
            # Reshape net's input blobs
            top[top_ind].reshape(*(blob.shape))
            # Copy data into net's input blobs
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)

2.2 get_ohem_minibatch

# 获取OHEM训练的batch
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM."""
    loss = np.array(loss)
	
	# 使用NMS过滤检测框
    if cfg.TRAIN.OHEM_USE_NMS:	# NMS thresh=0.7
        # Do NMS using loss for de-dup and diversity
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
        source_img_ids = [roi[0] for roi in rois] # 0或1(背景与前景)
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label, \
                                    source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                boxes = np.concatenate((rois[sel_indx, 1:],
                        loss[sel_indx][:,np.newaxis]), axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])

        hard_keep_inds = select_hard_examples(loss[keep_inds])  # 按照损失排序选择样本
        hard_inds = np.array(keep_inds)[hard_keep_inds]  # 最后保留下来的困难样本索引
    else:
        hard_inds = select_hard_examples(loss)

    blobs = { 
   'rois_hard': rois[hard_inds, :].copy(),
             'labels_hard': labels[hard_inds].copy()}
    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()

    return blobs

def select_hard_examples(loss):
    """Select hard rois."""
    # Sort and select top hard examples.
    sorted_indices = np.argsort(loss)[::-1]
    hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
    # (explore more ways of selecting examples in this function; e.g., sampling)

    return hard_keep_inds
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139136.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月30日 上午11:00
下一篇 2022年5月30日 上午11:16


相关推荐

  • smarty怎么用_item怎么用

    smarty怎么用_item怎么用1、简介含义:Smarty是PHP的一个引擎模板,可以更好的进行逻辑与显示的分离,即我们常说的MVC,这个引擎的作用就是将C分离出来。环境需求:PHP5.2或者更高版本我使用的环境是:PHP5.

    2022年8月5日
    8
  • 全新自适应地址发布页HTML源码【手机端】【pc端】

    全新自适应地址发布页HTML源码【手机端】【pc端】完整源码 点击下载 密码 ea9f

    2026年3月26日
    2
  • 通过QXDM锁BAND_不root怎么锁band

    通过QXDM锁BAND_不root怎么锁band1、通过QXDM锁频QXDM工具View->New->Common->NVBrowser:NV(NonVoliatile)参数就是保存在终端上的非易失参数,可以通过view中的NVBrowser来进行查看和修改。这些信息由厂家固化在终端内部,一般不允许用户修改。同时,可以通过NVBrowser对终端进行Offline(掉电重启)操作。其中06828LTEBCconfig可以配置终端支持的band信息,将该项的值读出来(默认读出来为十进制)转化为二进制,..

    2026年4月17日
    5
  • Java中的重载与重写的区别

    Java中的重载与重写的区别java中的重载与重写的区别1、重载发生在本类,重写发生在父类与子类之间;2、重载的方法名必须相同,重写的方法名相同且返回值类型必须相同;3、重载的参数列表不同,重写的参数列表必须相同。重载(Overloading)重载发生在本类,方法名相同,参数列表不同,与返回值无关,只和方法名,参数列表,参数的类型有关.重载(Overload):首先是位于一个类之中或者其子类中,具有相同的方法名,但是方法的参数不同,返回值类型可以相同也可以不同。重载的特征(1):方法名必须相同(2):方法的参数列表一

    2022年7月7日
    29
  • httprunner(3)用脚手架快速搭建项目[通俗易懂]

    httprunner(3)用脚手架快速搭建项目[通俗易懂]前言如何快速搭建一个httprunner项目呢?我们可以使用脚手架,脚手架就是自动地创建一些目录,形成一个项目的架构,不需要我们再手动的去创建查看创建新项目的命令先来查看一下帮助命令httpr

    2022年7月29日
    14
  • HTML5期末大作业:女装服装商城网站设计——女装服装商城(11页) HTML+CSS+JavaScript 学生DW网页设计作业成品 web课程设计网页

    HTML5期末大作业:女装服装商城网站设计——女装服装商城(11页) HTML+CSS+JavaScript 学生DW网页设计作业成品 web课程设计网页HTML5期末大作业:女装服装商城网站设计——女装服装商城(11页)HTML+CSS+JavaScript学生DW网页设计作业成品web课程设计网页常见网页设计作业题材有个人、美食、公司、学校、旅游、电商、宠物、电器、茶叶、家居、酒店、舞蹈、动漫、明星、服装、体育、化妆品、物流、环保、书籍、婚纱、军事、游戏、节日、戒烟、电影、摄影、文化、家乡、鲜花、礼品、汽车、其他等网页设计题目,A+水平作业,可满足大学生网页大作业网页设

    2025年8月31日
    5

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号