tensorflow模型查看参数(pytorch conv2d函数详解)

tf.nn.conv2d()参数解析

大家好,又见面了,我是你们的朋友全栈君。

定义:
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
功能:将两个4维的向量input(样本数据矩阵)和filter(卷积核)做卷积运算,输出卷积后的矩阵
input的形状:[batch, in_height ,in_width, in_channels]
batch: 样本的数量
in_height :每个样本的行数
in_width: 每个样本的列数
in_channels:每个样本的通道数,如果是RGB图像就是3
filter的形状:[filter_height, filter_width, in_channels, out_channels]
filter_height:卷积核的高
filter_width:卷积核的宽
in_channels:输入的通道数
out_channels:输出的通道数
比如在tensorflow的cifar10.py文件中有句:
这里写图片描述
卷积核大小为 5*5,输入通道数是3,输出通道数是64,即这一层输出64个特征
在看cifar10.py里第二层卷积核的定义:
这里写图片描述
大小依然是5*5,出入就是64个通道即上一层的输出,输出依然是64个特征
strides:[1,stride_h,stride_w,1]步长,即卷积核每次移动的步长
padding:填充模式取值,只能为”SAME”或”VALID”
卷积或池化后的节点数计算公式:
output_w = int((input_w + 2*padding – filter_w)/strid_w) + 1
举例说明:
假设这里使用的图像每副只有一行像素一通道,共3副图像

>>> a = np.array([[1,1,1],[2,2,2],[3,3,3]])
>>> b=tf.reshape(a,[a.shape[0],1,a.shape[1],1])
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(b)
array([[[[1], [1], [1]]], [[[2], [2], [2]]], [[[3], [3], [3]]]])

然后设有2个1*2的卷积核

>>> k=tf.constant([[[[ 1.0, 1.0]],[[2.0, 2.0]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3., 3.], [ 3., 3.], [ 1., 1.]]], [[[ 6., 6.], [ 6., 6.], [ 2., 2.]]], [[[ 9., 9.], [ 9., 9.], [ 3., 3.]]]], dtype=float32)
>>> sess.run(b)
array([[[[ 1.], [ 1.], [ 1.]]], [[[ 2.], [ 2.], [ 2.]]], [[[ 3.], [ 3.], [ 3.]]]], dtype=float32)
>>> sess.run(k)
array([[[[ 1., 1.]], [[ 2., 2.]]]], dtype=float32)

这里写图片描述
最后的0是函数自动填充的,所以最后就得到了一个2通道的卷积结果
将k改成[[ 1.0, 0.5],[2, 1]]然后再次运行:

>>> k=tf.constant([[[[ 1.0, 0.5]],[[2, 1]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3. , 1.5], [ 3. , 1.5], [ 1. , 0.5]]], [[[ 6. , 3. ], [ 6. , 3. ], [ 2. , 1. ]]], [[[ 9. , 4.5], [ 9. , 4.5], [ 3. , 1.5]]]], dtype=float32)

卷积核一般用tf.get_variable()初始化,这里为了演示直接指定为常量

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/129729.html原文链接:https://javaforall.net

(0)
上一篇 2022年4月14日 下午8:40
下一篇 2022年4月14日 下午8:40


相关推荐

  • mysql 时区设定_mysql的时区设置「建议收藏」

    mysql 时区设定_mysql的时区设置「建议收藏」IDEA配置mysql数据库时,地址,用户名,密码,数据库名填写之后,点测试连接,提示Serverreturnsinvalidtimezone.Goto’Advanced’tabandset’serverTimezone’prope如图翻译过来就是:服务器返回无效时区。进入“高级”选项卡,手动设置“serverTimezone”属性。网上查询了一下解决方案,原来是要设置时区…

    2025年7月4日
    3
  • TRILL技术——实现二层多路径转发

    TRILL技术——实现二层多路径转发TRILL TransparentI 多链路透明互联 是 IETF 为实现数据中心大二层扩展制定的一个标准 目前已经有一些协议文稿标准化 如 RFC6325 6326 6327 等等 该协议的核心思想是将成熟的三层路由的控制算法引入到二层交换中 将原先的 L2 报文加一个新的封装 隧道封装 转换到新的地址空间上进行转发

    2026年3月16日
    3
  • python turtle 表白_pythonturtle背景颜色

    python turtle 表白_pythonturtle背景颜色python中用turtle画爱心表白运行后的效果图:下面的代码是在python3.7写的,代码有点长,但却语法简单易懂代码如下:importturtlestr=input(‘请输入表白语:’)turtle.speed(10)#画笔速度turtle.setup(1800,700,70,70)turtle.color(‘black’,’pink’)#画笔颜色t…

    2025年9月28日
    3
  • 提问艺术「建议收藏」

    提问艺术「建议收藏」提问的艺术相信大部分老鸟当年都看过这篇经典的文章。在这里在转一次,以帮助大家能更好地问问题,以便获得更好的回答。先贴结论吧最后,不管是谁,来这里回答问题都是凭一腔热忱,凭兴趣和心情,如果版面充斥让人没有兴趣回答的问题,我想,对大家都不是好消息。自力更生真的很重要,不管你水平如何遇到什么样的困难,能自己解决多少就解决多少,然后再来求助,说需要什么什么帮助,多做一些努力只有好处

    2022年6月23日
    28
  • 回调金字塔是什么意思_回调地狱

    回调金字塔是什么意思_回调地狱如果你想阅读体验更好 可以戳链接回调地狱前言从前一文中你真的了解回调我们已知道回调函数是必须得依赖另一个函数执行调用 它是异步执行的 也就是需要时间等待 典型的例子就是 Ajax 应用 比如 http 请求 在不刷新浏览器的情况下 当你执行 DOM 事件时 比如页面上点击某链接 回车等事件操作 浏览器会悄悄向服务端发送若干 http 请求 携带后台可识别的参数 等待服务器响应返回数据 这个过程是异步回调的 当许多

    2026年3月19日
    1
  • Hunyuan-MT-7B快速部署:无需代码基础搭建翻译服务

    Hunyuan-MT-7B快速部署:无需代码基础搭建翻译服务

    2026年3月15日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号