tensorflow模型查看参数(pytorch conv2d函数详解)

tf.nn.conv2d()参数解析

大家好,又见面了,我是你们的朋友全栈君。

定义:
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
功能:将两个4维的向量input(样本数据矩阵)和filter(卷积核)做卷积运算,输出卷积后的矩阵
input的形状:[batch, in_height ,in_width, in_channels]
batch: 样本的数量
in_height :每个样本的行数
in_width: 每个样本的列数
in_channels:每个样本的通道数,如果是RGB图像就是3
filter的形状:[filter_height, filter_width, in_channels, out_channels]
filter_height:卷积核的高
filter_width:卷积核的宽
in_channels:输入的通道数
out_channels:输出的通道数
比如在tensorflow的cifar10.py文件中有句:
这里写图片描述
卷积核大小为 5*5,输入通道数是3,输出通道数是64,即这一层输出64个特征
在看cifar10.py里第二层卷积核的定义:
这里写图片描述
大小依然是5*5,出入就是64个通道即上一层的输出,输出依然是64个特征
strides:[1,stride_h,stride_w,1]步长,即卷积核每次移动的步长
padding:填充模式取值,只能为”SAME”或”VALID”
卷积或池化后的节点数计算公式:
output_w = int((input_w + 2*padding – filter_w)/strid_w) + 1
举例说明:
假设这里使用的图像每副只有一行像素一通道,共3副图像

>>> a = np.array([[1,1,1],[2,2,2],[3,3,3]])
>>> b=tf.reshape(a,[a.shape[0],1,a.shape[1],1])
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(b)
array([[[[1], [1], [1]]], [[[2], [2], [2]]], [[[3], [3], [3]]]])

然后设有2个1*2的卷积核

>>> k=tf.constant([[[[ 1.0, 1.0]],[[2.0, 2.0]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3., 3.], [ 3., 3.], [ 1., 1.]]], [[[ 6., 6.], [ 6., 6.], [ 2., 2.]]], [[[ 9., 9.], [ 9., 9.], [ 3., 3.]]]], dtype=float32)
>>> sess.run(b)
array([[[[ 1.], [ 1.], [ 1.]]], [[[ 2.], [ 2.], [ 2.]]], [[[ 3.], [ 3.], [ 3.]]]], dtype=float32)
>>> sess.run(k)
array([[[[ 1., 1.]], [[ 2., 2.]]]], dtype=float32)

这里写图片描述
最后的0是函数自动填充的,所以最后就得到了一个2通道的卷积结果
将k改成[[ 1.0, 0.5],[2, 1]]然后再次运行:

>>> k=tf.constant([[[[ 1.0, 0.5]],[[2, 1]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3. , 1.5], [ 3. , 1.5], [ 1. , 0.5]]], [[[ 6. , 3. ], [ 6. , 3. ], [ 2. , 1. ]]], [[[ 9. , 4.5], [ 9. , 4.5], [ 3. , 1.5]]]], dtype=float32)

卷积核一般用tf.get_variable()初始化,这里为了演示直接指定为常量

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/129729.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Negative Sampling 负采样详解[通俗易懂]

    Negative Sampling 负采样详解[通俗易懂]在word2vec中,为了简化训练的过程,经常会用到NegativeSampling负采样这个技巧,这个负采样到底是怎么样的呢?之前在我的博文word2vec算法理解和数学推导中对于word2vec有了很详细的数学推导,这里主要讲解一下负采样是如何降低word2vec的复杂度的。首先我们直接写出word2vec的目标函数,假设有一句话:query=w1,w2,w3,..,wnquery=…

    2022年6月26日
    85
  • MySQL详细学习教程(建议收藏)

    MySQL详细学习教程(建议收藏)目录1、初识数据库1.1、什么是数据库1.2、数据库分类1.3、相关概念1.4、MySQL及其安装1.5、基本命令2、操作数据库2.1、操作数据库2.2、数据库的列类型2.3、数据库的字段属性2.4、创建数据库表2.5、数据库存储引擎2.6、修改数据库3、MySQL数据管理3.1、外键3.2、DML语言1.添加insert2.修改update3.删除delete4、DQL查询数据4.1、基础查询4.2、条件查询4.3、分组查询4.4、连接查询4.5、排序和分页4.6、子查询4.7、MySQL函

    2022年10月3日
    2
  • Mysql登录时报错 ERROR 1045 (28000): 错误解决办法

    Mysql登录时报错 ERROR 1045 (28000): 错误解决办法本文转载自:http://www.cnblogs.com/zlslch/p/5937784.html错误问题的描述: ERROR1045(28000):Accessdeniedforuser’ODBC’@’localhost'(usingpassword:NO)ERROR1045(28000):Accessdeniedforuser’ODBC’

    2022年6月4日
    29
  • linux图形界面扩容lvm,linux下对LVM扩容

    linux图形界面扩容lvm,linux下对LVM扩容操作环境:VirtualBox下RedHat6.464位版本扩容原因:/dev/vg_rhel64/lv_root占用率达到100%,导致部分应用无法继续运行操作过程:1.关闭系统,在虚拟机中添加一块10G的磁盘2.查看新添加磁盘对应的名称[root@rhel64~]#fdisk-cul得知磁盘对应名称为/dev/sdc3.给sdc分区,只分一个区sdc1[root@rhel64…

    2022年6月20日
    22
  • pycharm 2021.11激活码【2021最新】

    (pycharm 2021.11激活码)2021最新分享一个能用的的激活码出来,希望能帮到需要激活的朋友。目前这个是能用的,但是用的人多了之后也会失效,会不定时更新的,大家持续关注此网站~IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.html…

    2022年3月29日
    61
  • Oracle 动态SQL「建议收藏」

    Oracle 动态SQL「建议收藏」Oracle动态SQL一、动态SQL的简介1、定义静态SQL是指直接嵌入到PL/SQL块中的SQL语句。动态SQL是指运行PL/SQL块是动态输入的SQL语句。2、适用范围如果在PL/SQL块中需要执行DDL语句(create,alter,drop等)、DCL语句(grant,revoke等)或更加灵活的SQL语句,需要用到动态SQL。

    2022年6月23日
    29

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号