pytorch 自定义卷积核进行卷积操作[通俗易懂]

pytorch 自定义卷积核进行卷积操作[通俗易懂]一卷积操作:在pytorch搭建起网络时,大家通常都使用已有的框架进行训练,在网络中使用最多就是卷积操作,最熟悉不过的就是torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True)通过上面的输入发现想自定义自己的卷积核,比如高斯…

大家好,又见面了,我是你们的朋友全栈君。

一 卷积操作:在pytorch搭建起网络时,大家通常都使用已有的框架进行训练,在网络中使用最多就是卷积操作,最熟悉不过的就是

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

通过上面的输入发现想自定义自己的卷积核,比如高斯核,发现是行不通的,因为上面的参数里面只有卷积核尺寸,而权值weight是通过梯度一直更新的,是不确定的。

二  需要自己定义卷积核的目的:目前是需要通过一个VGG网络提取特征特后需要对其进行高斯卷积,卷积后再继续输入到网络中训练。

三 解决方案。使用

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

pytorch 自定义卷积核进行卷积操作[通俗易懂]

 

这里注意下weight的参数。与nn.Conv2d的参数不一样

可以发现F.conv2d可以直接输入卷积的权值weight,也就是卷积核。那么接下来就要首先生成一个高斯权重了。这里不直接一步步写了,直接输入就行。

kernel = [[0.03797616, 0.044863533, 0.03797616],
         [0.044863533, 0.053, 0.044863533],
         [0.03797616, 0.044863533, 0.03797616]]

四 完整代码

class GaussianBlur(nn.Module):
    def __init__(self):
        super(GaussianBlur, self).__init__()
        kernel = [[0.03797616, 0.044863533, 0.03797616],
                  [0.044863533, 0.053, 0.044863533],
                  [0.03797616, 0.044863533, 0.03797616]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def forward(self, x):
        x1 = x[:, 0]
        x2 = x[:, 1]
        x3 = x[:, 2]
        x1 = F.conv2d(x1.unsqueeze(1), self.weight, padding=2)
        x2 = F.conv2d(x2.unsqueeze(1), self.weight, padding=2)
        x3 = F.conv2d(x3.unsqueeze(1), self.weight, padding=2)
        x = torch.cat([x1, x2, x3], dim=1)
        return x

 这里为了网络模型需要写成了一个类,这里假设输入的x也就是经过网络提取后的三通道特征图(当然不一定是三通道可以是任意通道)

如果是任意通道的话,使用torch.expand()向输入的维度前面进行扩充。如下:

    def blur(self, tensor_image):
        kernel = [[0.03797616, 0.044863533, 0.03797616],
               [0.044863533, 0.053, 0.044863533],
               [0.03797616, 0.044863533, 0.03797616]]
       
        min_batch=tensor_image.size()[0]
        channels=tensor_image.size()[1]
        out_channel=channels
        kernel = torch.FloatTensor(kernel).expand(out_channel,channels,3,3)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

        return F.conv2d(tensor_image,self.weight,1,1)

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/140363.html原文链接:https://javaforall.net

(1)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Http通过header传递参数_http contenttype

    Http通过header传递参数_http contenttype提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录前言一、header常用指令header分为三部分:发送一个200正常响应set404header:页面没找到页面被永久删除,可以告诉搜索引擎更新它们的urls访问受限服务器错误重定向到一个新的位置延迟一段时间后重定向覆盖X-Powered-Byvalue内容语言(en=English)最后修改时间(在缓存的时候可以用到)告诉浏览器要获取的内容还没有更新设置内容的长度(缓存的时候可以用到):用来下载文件:禁止

    2022年8月24日
    5
  • PhpStorm 2021.5 激活码[在线序列号]

    PhpStorm 2021.5 激活码[在线序列号],https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月18日
    39
  • ACM计算几何篇_acm数学

    ACM计算几何篇_acm数学1前言1.1计算几何算法1.2计算几何题目特点及要领1.3预备知识2凸包2.1定义2.1.1凸多边形2.1.2凸包2.2颜料配色问题2.2.1问题描述2.2.2问题简化2.2.3问题抽象2.2.4数学抽象2.2.4.1ConvexCombinationAndAffineCombination2.2.4.2区别与联系…

    2022年10月23日
    1
  • android源码学习-目录「建议收藏」

    android源码学习-目录「建议收藏」一年中感觉进步了不少,现在看android源码已经没有当初那么吃力了。但是和其他开发者的接触过程中,感觉自己对源码的了解还不是很透彻。android不需要所有的源码都去了解,但是几个重要的点的源码还是有必要理解清楚的。自己列了一个表,会去挨个的学习。1.android源码学习-事件分发处理机制2.android源码学习-View绘制流程3.android源码学习-activi…

    2022年6月5日
    23
  • 差分曼彻斯特编码详解「建议收藏」

    差分曼彻斯特编码详解「建议收藏」1.确定开始部位:第一个编码为0,表示从低到高第一个编码为1,表示从高到低;每一位由下面代替,表示信号的波动2.其次,下一位编码,遇0则跳动,遇1则不跳动

    2025年7月30日
    0
  • 1.2线性代数-行列式的性质

    行列式的性质:性质1:;行列式转置值不变对行成立的性质,对列也成立性质二:两行互换(两列互换),行列式的值要变号证明思路:若D中的每一项都和D1中的每一项差一个负号,那么D=-D13214是1234经过一次顺序变换得来的(1和3变换位置),1234为偶,3214肯定是奇原因:2,7,12,13列标的排法没变,只是行标变了。原来是1-2-3-4行,现在变成了3-2-1-4推论:两行或者两列对应相等,行列式值等于0若第一行和第三行互换,那么根据…

    2022年4月9日
    41

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号