经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性nbsp nbsp 不知不觉 笔者接触 Tensorflow 也满一年了 在这一年当中 笔者对 Tensorflow 的了解程度也逐渐加深 相比笔者接触的第一个深度学习框架 Caffe 而言 笔者认为 Tensorflow 更适合科研一些 网络搭建与算法设置的自由度也更大 使用 Tensorflow 实现自己的算法也更迅速 nbsp nbsp 但是 笔者认为 Tensorflow 还是有不足的地方 第一体现在 Tensorflow 的数据机制 由于 te

   不知不觉,笔者接触Tensorflow也满一年了。在这一年当中,笔者对Tensorflow的了解程度也逐渐加深。相比笔者接触的第一个深度学习框架Caffe而言,笔者认为Tensorflow更适合科研一些,网络搭建与算法设置的自由度也更大,使用Tensorflow实现自己的算法也更迅速。

   但是,笔者认为Tensorflow还是有不足的地方。第一体现在Tensorflow的数据机制,由于tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的。因此,在网络搭建的时候,是不能对tensor进行判值操作的,即不能插入if…else…之类的代码。第二,相较于numpy array,Tensorflow中对tensor的操作接口灵活性并没有那么高,使得Tensorflow的灵活性减弱。

   在笔者使用Tensorflow的一年中积累的编程经验来看,扩展Tensorflow程序的灵活性,有一个重要的手段,就是使用tf.py_func接口。笔者先对这个接口做出解析:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   在上图中,我们看到,tf.py_func的核心是一个func函数(由用户自己定义),该函数接收numpy array作为输入,并返回numpy array类型的输出。看到这里,大家应该能够明白为什么建议使用py_func,因为在func函数中,可以对转化成numpy array的tensor进行np.运算,这就大大扩展了程序的灵活性。

   然后,我们来看看tf.py_func接受什么参数:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   在使用tf.py_func的过程中,主要核心是使用前三个参数。

   第一个参数func,也是最重要的,是一个用户自定制的函数,输入numpy array,输出也是numpy array,在该函数中,可以自由使用np.操作。

   第二个参数inp,是func函数接收的输入,是一个列表

   第三个参数Tout,指定了func函数返回的numpy array转化成tensor后的格式,如果是返回个值,就是一个列表或元组;如果只有个返回值,就是一个单独的dtype类型(当然也可以用列表括起来)。

   最后来看看tf.py_func的输出:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   输出是一个tensor列表或单个tensor。

   到这里,tf.py_func的原理也就逐渐明晰了。首先,tf.py_func接收的是tensor,然后将其转化为numpy array送入func函数,最后再将func函数输出的numpy array转化为tensor返回。

   在使用过程中,有两个需要注意的地方,第一就是func函数的返回值类型一定要和Tout指定的tensor类型一致。第二就是,如下图所示,tf.py_func中的func是脱离Graph的。在func中不能定义可训练的参数参与网络训练(反传)。

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   上面就解析了tf.py_func的使用方法和原理。下面笔者举几个例子,一是向大家展示tf.py_func带来的灵活性,二是通过笔者的亲身体会说明一下如何使用tf.py_func完成一些Tensorflow基础编程中较难的任务。

1) tf.py_func在Faster R-CNN中的接口中的使用。

   在目标检测算法Faster R-CNN中,需要计算各种ground truth,接口比较复杂。因此,使用tf.py_func是一个比较好的途径。对于tf.py_func的使用,可以参见计算RPN的ground truth和计算proposals的ground truth时的使用方法。可以看到,都是将tensor转化成numpy array,再使用np.操作完成复杂运算。

   下面笔者来举两个小例子,说明一下tf.py_func的强大功能。

2) 使用tf.py_func获得未知tensor维度。

   大家知道,我们在做数据占位的时候,可能会传入”None”,即不知道数据的该维大小,取决于feed_dict中的实际值。可是,在运算中,要使用到数据的该维大小时应该怎么办呢?比如下面这个例子:

import tensorflow as tf import numpy as np def main(): a = tf.placeholder(tf.float32, shape=[1, 2], name = "tensor_a") b = tf.placeholder(tf.float32, shape=[None, 2], name = "tensor_b") tile_a = tf.tile(a, [b.get_shape()[0], 1]) sess = tf.Session() array_a = np.array([[1., 2.]]) array_b = np.array([[3., 4.],[5., 6.],[7., 8.]]) feed_dict = {a: array_a, b: array_b} tile_a_value = sess.run(tile_a, feed_dict = feed_dict) print(tile_a_value) if __name__ == '__main__': main()

   如上代码所示,要完成一个很简单的功能,就是扩张tensor a,将其的维度变成和tensor b一样,可是tensor b的维度暂时未知。我们来看看,执行上述程序能得到什么结果:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   可以看到,由于tensor b第一个维度未知,因此在给tile_a分配存储空间时报错,提示不能有None存在。

   如何解决这个问题?稍微改写一下上述代码,让tensor扩张在tf.py_func中执行:

import tensorflow as tf import numpy as np from py_func_1 import * def main(): a = tf.placeholder(tf.float32, shape=[1, 2], name = "tensor_a") b = tf.placeholder(tf.float32, shape=[None, 2], name = "tensor_b") tile_a = tile_tensor(a, b) sess = tf.Session() array_a = np.array([[1., 2.]]) array_b = np.array([[3., 4.],[5., 6.],[7., 8.]]) feed_dict = {a: array_a, b: array_b} tile_a_value = sess.run(tile_a, feed_dict = feed_dict) print(tile_a_value) if __name__ == '__main__': main()

   在上面的代码中,tensor扩张在tile_tensor这个函数中执行。该函数定义在py_func_1.py文件中,下面是py_func_1.py的代码:

import tensorflow as tf import numpy as np def tile_tensor(tensor_a, tensor_b): tile_tensor_a = tf.py_func(_tile_tensor, [tensor_a, tensor_b], tf.float32) return tile_tensor_a def _tile_tensor(a, b): tile_a = np.tile(a, (b.shape[0], 1)) return tile_a

   大家可以看到,使用了tf.py_func接口,参数func就是_tile_tensor函数。在
_tile_tensor函数中,将a扩张了,执行一下修改后的main函数,输出结果:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   大家可以看到,在tile_tensor函数中,tensor a在tensor b的维度未知的情况下,根据tensor b的实际维度([3, 2])将其扩张了。并返回了一个tensor类型的tile_a。

3) 在tf.py_func中对tensor的值作出判断。

   笔者在之前的博客中提到过,在tf.Session().run之前,是不能对Tensor的值做出判断的。比如,我们想根据tensor a的值对tensor b做出扩张:

import tensorflow as tf import numpy as np def main(): a = tf.placeholder(tf.float32, shape=[1], name = "tensor_a") b = tf.placeholder(tf.float32, shape=[1, 2], name = "tensor_b") tile_b = b if a[0]==1.: tile_b = tf.tile(b, [4, 1]) sess = tf.Session() array_a = np.array([1.]) array_b = np.array([[2., 3.]]) feed_dict = {a: array_a, b: array_b} tile_b_value = sess.run(tile_b, feed_dict = feed_dict) print(tile_b_value) if __name__ == '__main__': main()

   如果a[0]的值为1.0,那么就将tensor b扩张四倍。我们执行一下上述代码看看结果:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   大家可以看到,由于在if语句执行时,tensor a里面是空的。因此,不会执行if中的语句。尽管在feed_dict中a被填充了1.0,并且程序不报错,可是没有达到预想的目标。

   如何解决这个问题?稍微改写一下上述代码,让判值进行tensor扩张在tf.py_func中执行:

import tensorflow as tf import numpy as np from py_func_2 import * def main(): a = tf.placeholder(tf.float32, shape=[1], name = "tensor_a") b = tf.placeholder(tf.float32, shape=[1, 2], name = "tensor_b") tile_tensor_b = tile_b(a, b) sess = tf.Session() array_a = np.array([1.]) array_b = np.array([[2., 3.]]) feed_dict = {a: array_a, b: array_b} tile_b_value = sess.run(tile_tensor_b, feed_dict = feed_dict) print(tile_b_value) if __name__ == '__main__': main()

   大家可以看到,在py_func_2.py中的tile_b函数中,对tensor b进行了判值扩张。py_func_2.py代码如下所示:

import tensorflow as tf import numpy as np def tile_b(tensor_a, tensor_b): tile_tensor_b = tf.py_func(_tile_b, [tensor_a, tensor_b], tf.float32) return tile_tensor_b def _tile_b(a, b): if a[0]==1.: tile_b = np.tile(b, (4, 1)) else: tile_b = b return tile_b

   大家可以看到,在tile_b函数中有一个tf.py_func函数,其中的func参数便是_tile_b函数。在_tile_b函数中,根据a的值对b进行了扩张。我们来运行一下main函数,输出结果:

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

   tensor b得到了扩张!

   大家可以看到,在tensor输入进tf.py_func并转化成numpy array后,判值操作就有效了。

   通过上面的三个例子,笔者向大家揭示了tf.py_func函数中的神奇之处。大家可以看到,在实际使用中,将tensor转化为numpy array后,能够执行更灵活的操作,达到更多的目标。总而言之,tf.py_func是一个很强大的接口,也希望大家能在Tensorflow程序中灵活运用。


   欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力!


written by jiong

道阻且长,行则将至

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/224360.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 下午12:06
下一篇 2026年3月17日 下午12:07


相关推荐

  • pycharm远程部署_pycharm 远程调试

    pycharm远程部署_pycharm 远程调试在这之前你要确保服务器上已经创建好虚拟环境你本地已经安装好pycharm1创建本地文件远程服务器上已经有一个文件了。现在你在本地创建一个同名文件。服务器上的虚拟环境为DrQA,所以我在本地新建一个DrQA空文件夹。2用pycharm打开空项目3配置服务器的解释器左上角File→Setting→projectxxx→pythoninterpreter点右上角的小齿轮,然后点add选择SSHInterpreter,然后在上边填上服务器的地址、usernam

    2025年6月29日
    6
  • 利用websocket+Vuex完成一个实时聊天软件(前端部分)

    利用websocket+Vuex完成一个实时聊天软件(前端部分)这篇文章主要利用 websocked 建立长连接 利用 Vuex 全局通信的特性 以及 watch computed 函数实时监听消息变化 展示 实现一个实时聊天平台

    2026年3月26日
    2
  • getrealpath()_成语解释1000个

    getrealpath()_成语解释1000个getRealPath详细解释今天在获取路径的时候突然发现request中也有getRealPath这个方法,最后查了查文档,说是request.getRealPath(“”)不推荐使用,已摈弃。getServlet().getServletContext().getRealPath(“/”);可以取代上者,都是取得应用绝对路径。比如,有个servlet叫UploadServlet,它部署在tomcat下面以后的绝对路径如下:“C:\ProgramFiles\apache-tomcat-8.

    2026年1月26日
    4
  • shell输出数组元素_shell中使用数组

    shell输出数组元素_shell中使用数组数组介绍平时的定义a=1,b=2,c=3,变量如果多了,再一个一个定义很费劲,并且取变量的也费劲简单的说,数组就是相同数据类型的元素按一定顺序排列的集合数组就是把有限个类型相同的变量用一个名字命名,然后用编号区分他们得边合。这个名字成为数组名,编号成为数组下标。组成数组的各个变量成为数组的分称为数组的元素,有时也称为下标变量数组定义与增删改查法1:array=(value1value2valu…

    2025年7月23日
    5
  • http请求报400报错

    http请求报400报错400是HTTP的状态码,主要有两种形式:1、badrequest意思是“错误的请求”;2、invalidhostname意思是“不存在的域名”。在ajax请求后台数据时有时会报HTTP400错误-请求无效(Badrequest);出现这个请求无效报错说明请求没有进入到后台服务里;1、确认发送的数据格式是否正确。调试查看你发送的数据格式是否正确或是否有乱码…

    2022年6月12日
    97
  • eclipse更改maven的本地路径和外部仓库地址

    eclipse更改maven的本地路径和外部仓库地址

    2021年7月20日
    66

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号