如何正确的理解RPN网络的train和test[通俗易懂]

如何正确的理解RPN网络的train和test[通俗易懂]刚开始学FasterRCNN时,遇到这么一个困惑不知其他人有没有:RPN网络在程序中的训练是如何进行的?它都训练了网络中的哪些部分?其实这些我们如果不看源码都很难真正理解!我们以Faster-RCNN_TF的源码为例,以下代码取自./lib/networks/VGGnet_train.py#=========RPN============#以下代码的先后顺序我调整了一下,便…

大家好,又见面了,我是你们的朋友全栈君。

刚开始学Faster RCNN时,遇到些困惑不知其他人有没有:

1、RPN网络训练的输出是什么?
2、RPN网络在train中的作用是什么?
3、RPN网络在test中的作用是什么?
其实这些我们如果不看源码都很难真正理解!

以Faster-RCNN_TF的源码为例,以下代码取自./lib/networks/VGGnet_train.py

 #========= RPN ============
 #以下代码的先后顺序我调整了一下,便于理解
 (self.feed('conv5_3')
     .conv(3,3,512,1,1,name='rpn_conv/3x3')
     .conv(1,1,len(anchor_scales)*3*2 ,1 , 1, padding='VALID', relu = False, name='rpn_cls_score'))

 (self.feed('rpn_conv/3x3')
     .conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='VALID', relu = False, name='rpn_bbox_pred'))
     .anchor_target_layer(_feat_stride, anchor_scales, name = 'rpn-data' ))

重点

anchor_target_layer的返回值’rpn-data’,这是一个字典
key分别是:rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights

rpn_labels
是 [1,1,A*height,width],如果把它reshape成[1,A,height,width]会更好理解,即feature map上每一点
都是一个anchor,每个anchor对应A个bbox,如果一个bbox与gt_box的重叠度大于0.7(其实还有一个条件),就认为这个bbox包含一个前景,则
rpn_labels 矩阵中相应位置就设置为1。
gt_box的label不能直接用来做训练的目标(target),在训练中使用rpn_labels作为训练的目标
gt_box的唯一作用就在于判断产生的共A*W*H个bbox哪些属于前景,哪些不属于,将那些属于前景的bbox设置为训练目标去训练rpn_cls_score_reshape。
在test中,正好相反,训练好的网络会产生一个rpn_cls_score_reshape,它可以转化成一个[1,A,height,width]的矩阵
#proposal_layer 产生的[1,A,height,width]个bbox哪些属于前景,哪些属于背景。我们会把属于前景的挑出来,
按照得分排序,取前300个输入后面的fc层,fc层会产生两个输出:
一个是cls_score,用于判断bbox中物体的类型
另一个是bbox_pred,用于微调bbox,使其向gt_box进一步靠近(由于bbox都是从anchor产生的,他们不会和gt_box重合,还需要进一步微调)

rpn_bbox_targets
根据 rpn_labels 我们已经可以挑选出300个bbox,这些bbox都是在[1,W,H,A*4]中根据与gt_box的重合程度挑选出来的,与gt_box并不相同,有一些偏差,这些偏差表示为[dx,dy,dw,dh],这就是rpn_bbox_targets。
因为传进后面全卷积网络的是bbox,与gt_boxes不完全重合,为了使最终的结果更加接近gt_box,还需要进一步微调
而全卷积层的输出bbox_pred就是用于微调的,rpn_bbox_targets就是它训练的目标(target)
损失函数的计算:

# RPN
# classification loss
rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'),[-1,2])
rpn_label = tf.reshape(self.net.get_output('rpn-data')[0],[-1])
rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score,tf.where(tf.not_equal(rpn_label,-1))),[-1,2])
rpn_label = tf.reshape(tf.gather(rpn_label,tf.where(tf.not_equal(rpn_label,-1))),[-1])
rpn_cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

# bounding box regression L1 loss
rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')
rpn_bbox_targets = tf.transpose(self.net.get_output('rpn-data')[1],[0,2,3,1])
rpn_bbox_inside_weights = tf.transpose(self.net.get_output('rpn-data')[2],[0,2,3,1])
rpn_bbox_outside_weights = tf.transpose(self.net.get_output('rpn-data')[3],[0,2,3,1])

rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights)
rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, reduction_indices=[1, 2, 3]))

其余代码:

# Loss of rpn_cls & rpn_boxes

(self.feed('rpn_conv/3x3')
     .conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='VALID', relu = False, name='rpn_bbox_pred'))

#========= RoI Proposal ============
(self.feed('rpn_cls_score')
     .reshape_layer(2,name = 'rpn_cls_score_reshape')
     .softmax(name='rpn_cls_prob'))

(self.feed('rpn_cls_prob')
     .reshape_layer(len(anchor_scales)*3*2,name = 'rpn_cls_prob_reshape'))

(self.feed('rpn_cls_prob_reshape','rpn_bbox_pred','im_info')
     .proposal_layer(_feat_stride, anchor_scales, 'TRAIN',name = 'rpn_rois'))

(self.feed('rpn_rois','gt_boxes')
     .proposal_target_layer(n_classes,name = 'roi-data'))


#========= RCNN ============
(self.feed('conv5_3', 'roi-data')
     .roi_pool(7, 7, 1.0/16, name='pool_5')
     .fc(4096, name='fc6')
     .dropout(0.5, name='drop6')
     .fc(4096, name='fc7')
     .dropout(0.5, name='drop7')
     .fc(n_classes, relu=False, name='cls_score')
     .softmax(name='cls_prob'))

(self.feed('drop7')
     .fc(n_classes*4, relu=False, name='bbox_pred'))
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152315.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • nextline函数_nextLine()和next()的区别和使用方法

    nextline函数_nextLine()和next()的区别和使用方法最近在笔试,刷剑指Offer时,都是只需要把方法实现了就行。但是!!!笔试时候会发现,大部分会要求你把main函数也code出来,真是醉了,第一次笔试时候搞的晕乎乎的…..废话不多说,那么在写输入输出中,肯定要用到Scanner类了,其中不少都需要读取一个整数或者一个整型数组。当读入整数时(以int为例),直接就nextInt()就好,可是当读入一个整型数组时(数字之间用空格隔开),就涉及到用…

    2022年6月8日
    30
  • docker 修改容器时间_jenkins docker持续集成

    docker 修改容器时间_jenkins docker持续集成前言用docker搭建的Jenkins环境时间显示和我们本地时间相差8个小时,需修改容器内部的系统时间查看时间查看系统时间date-R进入docker容器内部,查看容器时间dockere

    2022年7月29日
    4
  • [转载]windows phone 墓碑化(9)

    [转载]windows phone 墓碑化(9)

    2021年8月20日
    46
  • vue与jquery的区别_vue 3

    vue与jquery的区别_vue 31.jquery介绍:想必大家都用过jquery吧,这个曾经也是现在依然最流行的web前端js库,可是现在无论是国内还是国外他的使用率正在渐渐被其他的js库所代替,随着浏览器厂商对HTML5规范统一遵循以及ECMA6在浏览器端的实现,jquery的使用率将会越来越低2.vue介绍:vue是一个兴起的前端js库,是一个精简的MVVM。从技术角度讲,Vue.js专注于MVVM模型的ViewM…

    2022年10月16日
    0
  • 阿里笔试题(2015)持续更新中

    阿里笔试题(2015)持续更新中第一次做阿里笔试题,除了ACM题之外从来没有做过校招网络题呀,完全是裸考,总体感觉吧,对于我来说,感觉时间不够用,不是题不会,感觉时间紧,大脑很混乱,总结这一次的笔试题废话不多说,直接上题和答案平均每个人逗留时间为20分钟,那么开场前20分钟一共来了400人,且有20个人逗留时间已经到,应该容纳400人双向循环列表,从任何一个元素开始可以遍历全部元素先和后面的元素相

    2022年5月24日
    36
  • 探寻京东云核心竞争力的源泉「建议收藏」

    探寻京东云核心竞争力的源泉「建议收藏」云计算服务提供商的核心竞争力有哪些?除了技术、产品与服务之外,基础设施亦是不可忽视的一大因素。之所以会如此,是因为云计算是一个堪称“三高”的市场:高技术壁垒、高投资投入、高市场增长,云服务提供商需要保持长期投入,通过规模效应来实现成本优势,从而吸引更多用户采用其相关服务与产品。数据中心等基础设施的建设是云服务提供商实现持续成长的关键所在。数据不会骗人。根据咨询机构SynergyRese…

    2022年10月8日
    0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号