OHEM代码梳理[通俗易懂]

OHEM代码梳理[通俗易懂]传送门:相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结代码地址:OHEM1.前言有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得…

大家好,又见面了,我是你们的朋友全栈君。

  • 传送门
  1. 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
  2. 代码地址:OHEM

1. 前言

有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。

layer {
  name: "hard_roi_mining"
  type: "Python"
  bottom: "cls_prob_readonly"
  bottom: "bbox_pred_readonly"
  bottom: "rois"
  bottom: "labels"
  bottom: "bbox_targets"
  bottom: "bbox_inside_weights"
  bottom: "bbox_outside_weights"
  top: "rois_hard"
  top: "labels_hard"
  top: "bbox_targets_hard"
  top: "bbox_inside_weights_hard"
  top: "bbox_outside_weights_hard"
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  propagate_down: false
  python_param {
    module: "roi_data_layer.layer"
    layer: "OHEMDataLayer"
    param_str: "'num_classes': 6" #6
  }
}

2. OHEM代码简单梳理

2.1 OHEMDataLayer

class OHEMDataLayer(caffe.Layer):
    """Online Hard-example Mining Layer."""
    def setup(self, bottom, top):
        """Setup the OHEMDataLayer."""

        # parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)

        self._num_classes = layer_params['num_classes']	 # 获取分类数目

        self._name_to_bottom_map = { 
     # 将bottom的blob名称与index使用dict关联
            'cls_prob_readonly': 0,
            'bbox_pred_readonly': 1,
            'rois': 2,
            'labels': 3}

        if cfg.TRAIN.BBOX_REG:  # 有边界框回归
            self._name_to_bottom_map['bbox_targets'] = 4
            self._name_to_bottom_map['bbox_loss_weights'] = 5

        self._name_to_top_map = { 
   }  # 同理top的blob名称也要与index关联起来
        ……
	
	# 前向传播函数
    def forward(self, bottom, top):
        """Compute loss, select RoIs using OHEM. Use RoIs to get blobs and copy them into this layer's top blob vector."""

        cls_prob = bottom[0].data  # 获取对应bottom的数据
        bbox_pred = bottom[1].data
        rois = bottom[2].data
        labels = bottom[3].data
        if cfg.TRAIN.BBOX_REG:
            bbox_target = bottom[4].data
            bbox_inside_weights = bottom[5].data
            bbox_outside_weights = bottom[6].data
        else:
            bbox_target = None
            bbox_inside_weights = None
            bbox_outside_weights = None

        flt_min = np.finfo(float).eps
        # 计算分类的损失
        loss = [ -1 * np.log(max(x, flt_min)) \
            for x in [cls_prob[i,label] for i, label in enumerate(labels)]]

        # 计算边界框回归的损失,并且将两个损失加起来
        if cfg.TRAIN.BBOX_REG:
            # bounding-box regression loss
            # d := w * (b0 - b1)
            # smoothL1(x) = 0.5 * x^2 if |x| < 1
            # |x| - 0.5 otherwise
            def smoothL1(x):
                if abs(x) < 1:
                    return 0.5 * x * x
                else:
                    return abs(x) - 0.5

            bbox_loss = np.zeros(labels.shape[0])  # 边界框损失
            for i in np.where(labels > 0 )[0]:
                indices = np.where(bbox_inside_weights[i,:] != 0)[0]
                bbox_loss[i] = sum(bbox_outside_weights[i,indices] * [smoothL1(x) \
                    for x in bbox_inside_weights[i,indices] * (bbox_pred[i,indices] - bbox_target[i,indices])])
            loss += bbox_loss  # 两者损失相加

        # 筛选出损失比较大的返回
        blobs = get_ohem_minibatch(loss, rois, labels, bbox_target, \
            bbox_inside_weights, bbox_outside_weights)

		# 给top blob赋值
        for blob_name, blob in blobs.iteritems():
            top_ind = self._name_to_top_map[blob_name]
            # Reshape net's input blobs
            top[top_ind].reshape(*(blob.shape))
            # Copy data into net's input blobs
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)

2.2 get_ohem_minibatch

# 获取OHEM训练的batch
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM."""
    loss = np.array(loss)
	
	# 使用NMS过滤检测框
    if cfg.TRAIN.OHEM_USE_NMS:	# NMS thresh=0.7
        # Do NMS using loss for de-dup and diversity
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH  # 0.7
        source_img_ids = [roi[0] for roi in rois] # 0或1(背景与前景)
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label, \
                                    source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                boxes = np.concatenate((rois[sel_indx, 1:],
                        loss[sel_indx][:,np.newaxis]), axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])

        hard_keep_inds = select_hard_examples(loss[keep_inds])  # 按照损失排序选择样本
        hard_inds = np.array(keep_inds)[hard_keep_inds]  # 最后保留下来的困难样本索引
    else:
        hard_inds = select_hard_examples(loss)

    blobs = { 
   'rois_hard': rois[hard_inds, :].copy(),
             'labels_hard': labels[hard_inds].copy()}
    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = bbox_outside_weights[hard_inds, :].copy()

    return blobs

def select_hard_examples(loss):
    """Select hard rois."""
    # Sort and select top hard examples.
    sorted_indices = np.argsort(loss)[::-1]
    hard_keep_inds = sorted_indices[0:np.minimum(len(loss), cfg.TRAIN.BATCH_SIZE)]
    # (explore more ways of selecting examples in this function; e.g., sampling)

    return hard_keep_inds
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139136.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • idea免费激活码2021(JetBrains全家桶)[通俗易懂]

    (idea免费激活码2021)最近有小伙伴私信我,问我这边有没有免费的intellijIdea的激活码,然后我将全栈君台教程分享给他了。激活成功之后他一直表示感谢,哈哈~https://javaforall.net/100143.htmlIntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,上面是详细链接哦~MLZP…

    2022年3月20日
    53
  • Python udp编程_python socket udp

    Python udp编程_python socket udpTCP是建立可靠连接,并且通信双方都可以以流的形式发送数据。相对TCP,UDP则是面向无连接的协议。使用UDP协议时,不需要建立连接,只需要知道对方的IP地址和端口号,就可以直接发数据包。但是,能不能到达就不知道了。虽然用UDP传输数据不可靠,但它的优点是和TCP比,速度快,对于不要求可靠到达的数据,就可以使用UDP协议。我们来看看如何通过UDP协议传输数据。和TCP类似,

    2025年10月2日
    5
  • shuffle单级互连网络_如何看论文

    shuffle单级互连网络_如何看论文ShuffleNetShuffleNet:AnExtremelyEfficientConvolutionalNeuralNetworkforMobileDevices原文地址:ShuffleNet代码:-TensorFlow-CaffeAbstract论文介绍一个效率极高的CNN架构ShuffleNet,专门应用于计算力受限的移动设备。新

    2025年10月16日
    3
  • 微信小程序 宠物论坛1[通俗易懂]

    微信小程序 宠物论坛1[通俗易懂]微信小程序宠物论坛(1)一个简单的论坛包括以下几个方面:登录模块发帖模块首页模块帖子详情模块搜索模块个人主页模块下面将从这6个方面介绍如何用微信小程序开发一个简单的论坛。1、登录模块先看界面图打开小程序首先看到这个界面,之后我们点击头像便完成授权登录。JS部分//index.js//获取应用实例constapp=getApp()constdb=wx.cloud.database()Page({data:{motto:’欢迎来到宠物论坛

    2022年10月7日
    4
  • linux的通配符有哪些,Linux通配符「建议收藏」

    linux的通配符有哪些,Linux通配符「建议收藏」Linux通配符说明:通配符是bash的内置功能,几乎适用于所有Linux命令。*匹配任意(0个或多个)字符或字符串,包括空字符串。?匹配任意1个字符,有且只有一个字符。[abcd]匹配abcd中任何一个字符,abcd也可以是其他任意不连续字符。[a-z]匹配a到z之间的任意一个字符,字符前后要连续,也可以用连续数字,即[1-9]。[!abcd]表示不匹配括号里面的任何一个字符…

    2022年9月19日
    2
  • oracle正则表达式匹配中文

    oracle正则表达式匹配中文oracle正则表达式regexp_substr、regexp_like、regexp_replace是无法像其他正则表达式一样用[\u4e00-\u9fa5]来匹配中文的。所以,我们需要用另一种方式来实现oracle正则表达式匹配中文。我们需要用到oracle的内置函数UNISTR(str):ASCIISTR语法:asciistr(str)功能:返回字符串的规则表现形式,英文和数字变

    2022年6月28日
    34

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号