特征金字塔池化

特征金字塔池化1 特征金字塔池化如上图所示 将特征图的所有像素划分为 n nn timesnn n 个网格 并将其经过核为 n nn timesnn n 步长为 n n n n n n 的池化 可以选择最大池化或者平均池化 经过较为密集的池化 4 times4 形成 形成 形成 N timesC timesn timesn 的特征图 将其串联形成的特征图 将其串联形成的特征图 将其串联形成 C times 的格式 之后 在第 2 个步骤池化得到的特征图的基础上 取不同的 nnn 值 进行下一个池

1. 特征金字塔池化

特征金字塔池化

如上图所示:

  1. 将特征图的所有像素划分为 n × n n\times n n×n个网格,对每个网格进行池化,池化层的核大小即为网格大小,宽度不符合时可以padding
  2. 取不同的n值,重复1过程;
  3. 将上述过程得到的所有结果经过flatten和concat,得到 C × N C\times N C×N格式的特征图,可以直接用于全连接。

输出的结果只与 n n n值和通道数量相关,而与输入Tensor的形状无关(当然不能太小,否则池化结果为0)

2. 实现

完整代码连接:古承风的gitee

以下是核心代码


def _spp_layer(self,x:torch.Tensor,mode='max',grid_nums:list=[16]):
        """ output_num denote an grid's width steps: --- 1. compute width for specific output_num, sqrt(num) 2. compute pooling's kernel_size and stride 3. pooling 4. concat all the output """
        N,C,H,W = x.size()
        for i in range(len(grid_nums)):
            # step1
            
            h = ceil(H/(sqrt(grid_nums[i])))
            w = ceil(W/(sqrt(grid_nums[i])))
            
            h_pad = int(((h*sqrt(grid_nums[i])+1)-H)/2)
            w_pad = int(((w*sqrt(grid_nums[i])+1)-W)/2)
            # step2
            if mode == "max":
                pool = nn.MaxPool2d(kernel_size=(h,w),stride=(h,w),padding=(h_pad,w_pad)) 
            elif mode=='avg':
                pool = nn.AvgPool2d(kernel_size=(h,w),stride=(h,2),padding=(h_pad,w_pad))
            else:
                raise ValueError(f"{ 
      mode} mode type error ,expect 'max' and 'avg'")
            
            temp = pool(x) # to origin x , means pyramid pooling
            
            # if for fully connected , could use this concat method
            if i == 0:
                output = temp.view(N,-1)
            else:
                output = torch.concat((output,temp.view(N,-1)),-1)

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/232170.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号