【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()

【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来importtensorflowastfimportnumpyasnpa=tf.constant([1,2,3,4])b=tf.square(a)withtf.Session()assess:print(“b:%s”%sess.run(b))#b:[14916]…

大家好,又见面了,我是你们的朋友全栈君。

tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来

import tensorflow as tf
import numpy as np
a = tf.constant([1,2,3,4])
b = tf.square(a)
with tf.Session() as sess:
    print("b:%s" % sess.run(b))
# b:[ 1  4  9 16]
import numpy as np
import tensorflow as tf
def bbox_ohem(bbox_pred,bbox_target,label):
    '''
    :param bbox_pred:
    :param bbox_target:
    :param label: class label
    :return: mean euclidean loss for all the pos and part examples
    '''
    zeros_index = tf.zeros_like(label, dtype=tf.float32)
    ones_index = tf.ones_like(label, dtype=tf.float32)
    #获取pos样本和part样本
    valid_inds = tf.where(tf.equal(tf.abs(label),1),ones_index,zeros_index)
    #(batch,)
    #计算平方和(按行)tf.square(bbox_pred-bbox_target): 求每个数的平方值
    square_error = tf.square(bbox_pred-bbox_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    with tf.Session() as sess:
        print("bbox_pred-bbox_target:%s"%(sess.run(bbox_pred-bbox_target)))
        print("square_error:%s" % (sess.run(square_error)))
    # 计算pos样本和part样本的数量
    num_valid = tf.reduce_sum(valid_inds)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    # 去掉neg样本和landmark样本的平方和
    square_error = square_error*valid_inds
    # 获取前K个样本的索引,K为pos和part样本的数量
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    # 将所有pos样本和part样本的平方和提取出来
    square_error = tf.gather(square_error, k_index)
    # 返回均值
    return tf.reduce_mean(square_error)

bbox_pred = tf.random_uniform([2,4],10,100,seed = 100)
bbox_target = tf.random_uniform([2,4],15,150,seed = 100)
with tf.Session() as sess:
    print("cls_prob:%s"%(sess.run(bbox_pred)))
label = np.array([1,0])
bbox_ohem(bbox_pred,bbox_target,label)

在这里插入图片描述

landmark_ohem:作用就是返回landmark的损失,用的是landmark样本。

def landmark_ohem(landmark_pred,landmark_target,label):
    '''

    :param landmark_pred:
    :param landmark_target:
    :param label:
    :return: mean euclidean loss
    '''
    #keep label =-2  then do landmark detection
    ones = tf.ones_like(label,dtype=tf.float32)
    zeros = tf.zeros_like(label,dtype=tf.float32)
    valid_inds = tf.where(tf.equal(label,-2),ones,zeros)
    square_error = tf.square(landmark_pred-landmark_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    num_valid = tf.reduce_sum(valid_inds)
    #keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    square_error = square_error*valid_inds
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    square_error = tf.gather(square_error, k_index)
    return tf.reduce_mean(square_error)
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139359.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月29日 下午11:00
下一篇 2022年5月29日 下午11:00


相关推荐

  • ubuntu安装增强功能失败_ubuntu参考的对象不支持

    ubuntu安装增强功能失败_ubuntu参考的对象不支持Ubuntu换源后,更新提示GPGerror缺少公钥W:GPGerror:http://mirrors.aliyun.comtrusty-securityInRelease:Thefollowingsignaturescouldn’tbeverifiedbecausethepublickeyisnotavailable:NO_PUBKEY40976EAF437D05B5NO_PUBKEY3B4FE6ACC0B21F32W:GPGerror:http:/

    2022年8月31日
    5
  • SD卡中FAT32文件格式高速入门(图文具体介绍)

    SD卡中FAT32文件格式高速入门(图文具体介绍)

    2021年12月6日
    124
  • 安全帽识别软件使用中常见问题分析[通俗易懂]

    安全帽识别软件使用中常见问题分析[通俗易懂]一、安全帽识别软件的主要功能是什么?安全帽识别是通俗的说法,相对准确的名称应该是安全帽佩戴检测,是用深度学习的算法对视频流进行分析,通过人工智能来判断视频中的人是否未佩戴安全帽,如果未佩戴,则触发告警规则。二、安全帽识别软件的技术成熟吗?2012年人工智能领域的卷积神经网络迎来重大突破,深圳强美随即将此尖端技术应用于工业安全监控,因为掌握海量样本数据的先天优势,鹰眸安全帽(佩戴检测)识别系…

    2022年5月19日
    57
  • golang []byte和string相互转换

    golang []byte和string相互转换测试例子:packagemainimport(“fmt”)funcmain(){str2:=”hello”data2:=[]byte(str2)fmt.Println(data2)str2=string(data2[:])fmt.Println(str2)}

    2022年6月17日
    28
  • 亲身经历从软通外包到华为OD,两者有什么区别?「建议收藏」

    亲身经历从软通外包到华为OD,两者有什么区别?「建议收藏」亲身经历从软通外包到华为OD,两者有什么区别?声明:本人所有言论仅限2021-04当前真实所在的部门情况。序言​ 坐标南京,本人2014年毕业于211本科院校,16年底加入软通动力,20年初转入华为OD。到如今算是经历了完整的OD模式。从被华为沟通加入OD,尝试第一次了解它的时候开始,网评就在外包/OD/自有三者之间疯狂比较。那本文就从外包/OD到底有什么区别?OD离自有有多远?来说说在如今华为社招基本停工的局面下,该如何面对华为这个ICT巨兽的招聘?希望对有些迷茫的人提供一些帮助。网上的声音有很多,

    2022年7月17日
    80
  • C语言总结(一维数组、二维数组、字符数组和字符串)

    C语言总结(一维数组、二维数组、字符数组和字符串)C 语言总结第七章 数组一维数组一维数组的定义一维数组的引用一维数组的初始化程序举例二维数组及多维数组二维数组的定义二维数组元素的引用二维数组元素的初始化程序举例字符数组和字符串字符数组第七章 数组数组是构造数据类型之一数组 有序数据的集合 用数组名标识元素 属同一数据类型 用数组名和下标确定一维数组一维数组的定义定义方式 例 inta 6 一维数组的引用 1 数组

    2026年3月26日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号