【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()

【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来importtensorflowastfimportnumpyasnpa=tf.constant([1,2,3,4])b=tf.square(a)withtf.Session()assess:print(“b:%s”%sess.run(b))#b:[14916]…

大家好,又见面了,我是你们的朋友全栈君。

tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来

import tensorflow as tf
import numpy as np
a = tf.constant([1,2,3,4])
b = tf.square(a)
with tf.Session() as sess:
    print("b:%s" % sess.run(b))
# b:[ 1  4  9 16]
import numpy as np
import tensorflow as tf
def bbox_ohem(bbox_pred,bbox_target,label):
    '''
    :param bbox_pred:
    :param bbox_target:
    :param label: class label
    :return: mean euclidean loss for all the pos and part examples
    '''
    zeros_index = tf.zeros_like(label, dtype=tf.float32)
    ones_index = tf.ones_like(label, dtype=tf.float32)
    #获取pos样本和part样本
    valid_inds = tf.where(tf.equal(tf.abs(label),1),ones_index,zeros_index)
    #(batch,)
    #计算平方和(按行)tf.square(bbox_pred-bbox_target): 求每个数的平方值
    square_error = tf.square(bbox_pred-bbox_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    with tf.Session() as sess:
        print("bbox_pred-bbox_target:%s"%(sess.run(bbox_pred-bbox_target)))
        print("square_error:%s" % (sess.run(square_error)))
    # 计算pos样本和part样本的数量
    num_valid = tf.reduce_sum(valid_inds)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    # 去掉neg样本和landmark样本的平方和
    square_error = square_error*valid_inds
    # 获取前K个样本的索引,K为pos和part样本的数量
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    # 将所有pos样本和part样本的平方和提取出来
    square_error = tf.gather(square_error, k_index)
    # 返回均值
    return tf.reduce_mean(square_error)

bbox_pred = tf.random_uniform([2,4],10,100,seed = 100)
bbox_target = tf.random_uniform([2,4],15,150,seed = 100)
with tf.Session() as sess:
    print("cls_prob:%s"%(sess.run(bbox_pred)))
label = np.array([1,0])
bbox_ohem(bbox_pred,bbox_target,label)

在这里插入图片描述

landmark_ohem:作用就是返回landmark的损失,用的是landmark样本。

def landmark_ohem(landmark_pred,landmark_target,label):
    '''

    :param landmark_pred:
    :param landmark_target:
    :param label:
    :return: mean euclidean loss
    '''
    #keep label =-2  then do landmark detection
    ones = tf.ones_like(label,dtype=tf.float32)
    zeros = tf.zeros_like(label,dtype=tf.float32)
    valid_inds = tf.where(tf.equal(label,-2),ones,zeros)
    square_error = tf.square(landmark_pred-landmark_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    num_valid = tf.reduce_sum(valid_inds)
    #keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    square_error = square_error*valid_inds
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    square_error = tf.gather(square_error, k_index)
    return tf.reduce_mean(square_error)
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139359.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • pycharm的_pycharm conda

    pycharm的_pycharm conda不知道朋友们用过maven没有,使用JAVA编程的人应该有人用过这个有趣的东西,JAVA导包是容易的,然而,懒是没有极限了,所以maven出来了,一个丰满的开发包仓库,不需要你再去哪儿找找找。但是这又算得了什么,我们伟大的Python怎么可能弱,pip包安装管理器就是这样的存在,他使得安装pymodel变得和在linux下安装软件一样容易,只要简单的一句pipinstallsimplename

    2022年8月28日
    2
  • effective C++ 读书笔记 条款08「建议收藏」

    effective C++ 读书笔记 条款08

    2022年2月7日
    44
  • C语言qsort函数用法

    C语言qsort函数用法qsort函数简介   排序方法有很多种:选择排序,冒泡排序,归并排序,快速排序等。看名字都知道快速排序是目前公认的一种比较好的排序算法。因为他速度很快,所以系统也在库里实现这个算法,便于我们的使用。这就是qsort函数(全称quicksort)。它是ANSIC标准中提供的,其声明在stdlib.h文件中,是根据二分法写的,其时间复杂度为n*log(n)  功能:

    2022年6月23日
    26
  • 集成灶功能图标解释_为什么不建议买集成灶

    集成灶功能图标解释_为什么不建议买集成灶GitLab Auto DevOps功能与Kubernetes集成教程

    2022年4月22日
    68
  • pycharm怎么配置tensorflow环境_pycharm环境搭建

    pycharm怎么配置tensorflow环境_pycharm环境搭建Pycharm安装并搭建Tensorflow开发环境下载并安装pycharm1.下载2.pycharm配置python环境安装tensorflow1.输入清华仓库镜像2.创建tensorflow环境3.启动tensorflow环境4.安装cpu版本的TensorFlow5.测试TensorFlowPycharm中配置TensorFlow环境在操作之前先安装好python环境,我是安装的Anaconda,Anaconda下载安装教程可参考:https://blog.csdn.net/Chen_Meng_

    2022年8月26日
    6
  • 光流法小结[通俗易懂]

    光流法小结[通俗易懂]1.定义空间运动物体在观察成像平面上的像素运动的瞬时速度,是利用图像序列中像素在时间域上的变化以及相邻帧之间的相关性来找到上一帧跟当前帧之间存在的对应关系,从而计算出相邻帧之间物体的运动信息的一种方法。也就是说,由空间域到图像平面的投影。而通俗来讲,把图像中的每一个点的瞬时速度和方向找出来就是光流。2.光流有什么用通过光流判断物体距离我们的远近。一般而言,远景的物体相对来说光流较小,而近景物体

    2022年7月23日
    10

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号