ResNet18复现「建议收藏」

ResNet18复现「建议收藏」ResNet18的网络架构图首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。代码如下:importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFclassBasicBlock(nn.Module):def__ini

大家好,又见面了,我是你们的朋友全栈君。

ResNet18的网络架构图

ResNet18复现「建议收藏」

首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。

代码如下:

import torch
from torch import nn
from torch.nn import functional as F

class BasicBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicBlock,self).__init__()
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size,stride,padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        return F.relu(x+output)
    

class BasicDownBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicDownBlock,self).__init__()     
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size[0],stride[0],padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size[0],stride[1],padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        self.conv3=nn.Conv2d(in_channels,out_channels,kernel_size[1],stride[0])
        self.bn3=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        output1=self.bn3(self.conv3(x))
        return F.relu(output1+output)

class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        # 3x224x224-->64x112x112
        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3)
        self.bn1=nn.BatchNorm2d(64)
        # 64x112x112-->64x56x56
        self.pool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        # 64x56x56-->64x56x56
        self.layer1=nn.Sequential(
            BasicBlock(64,64,3,1),
            BasicBlock(64,64,3,1)
        )
        # 64x56x56-->128*28*28
        self.layer2=nn.Sequential(
            BasicDownBlock(64,128,[3,1],[2,1]),
            BasicBlock(128,128,3,1)
        )
        # 128*28*28-->256*14*14
        self.layer3=nn.Sequential(
            BasicDownBlock(128,256,[3,1],[2,1]),
            BasicBlock(256,256,3,1)
        )
        # 256*14*14-->512x7x7
        self.layer4=nn.Sequential(
            BasicDownBlock(256,512,[7,1],[2,1]),
            BasicBlock(512,512,3,1)
        )
        # 512x7x7-->512x1x1
        self.avgpool=nn.AdaptiveMaxPool2d(output_size=(1,1))
        self.flat=nn.Flatten()
        self.linear=nn.Linear(512,10)
        
    def forward(self,x):
        output=self.pool1(F.relu(self.bn1(self.conv1(x))))
        output=self.layer1(output)
        output=self.layer2(output)
        output=self.layer3(output)
        output=self.layer4(output)
        output=self.avgpool(output)
        output=self.flat(output)
        output=self.linear(output)
        return output
    

net=ResNet18()
x=torch.randn(32,3,224,224)
print(x.shape)
y=net(x)
print(y.shape)

代码中BasicBlock为普通的残差块,注意步长和卷积核的大小,BasicDownBlock为下采样的残差块,然后将四层的网络表示出来,最后进行验证x.shape为torch.Size([32, 3, 224, 224]),y.shape为torch.Size([32, 10])。 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141545.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • MySQL8.0正确修改密码的姿势[通俗易懂]

    MySQL8.0正确修改密码的姿势[通俗易懂]mysql更新完密码,总是拒绝连接、登录失败?MySQL8.0不能通过直接修改mysql.user表来更改密码。正确更改密码的方式备注:清空root密码MySQL8.0不能通过直接修改mysql.user表来更改密码。因为authentication_string字段下只能是MySQL加密后的43位字符串密码,其他的导致错误。错误不报出,但是无法再登录mysql,总是会提示无…

    2022年8月13日
    4
  • free技术详解 lock_lock free的理解

    free技术详解 lock_lock free的理解转自:http://www.isnowfy.com/understand-to-lock-free/以前一直不明白lockfree是什么,后来发现原来是完全理解错了概念,lockfree看到大家有的翻译为无锁,有的翻译为锁无关,其实用不用锁和lockfree是不相关的,用了锁也可能是lockfree,而不用锁有可能不是lockfree。一个lockfree的解释是一个“锁无关”的程序能…

    2022年7月19日
    24
  • 调用ShellExecute所须要头文件

    调用ShellExecute所须要头文件

    2021年12月7日
    56
  • Spark pyspark rdd连接函数之join、leftOuterJoin、rightOuterJoin和fullOuterJoin介绍

    Spark pyspark rdd连接函数之join、leftOuterJoin、rightOuterJoin和fullOuterJoin介绍Sparkpysparkrdd连接函数之join、leftOuterJoin、rightOuterJoin和fullOuterJoin介绍union用于组合两个rdd的元素,join用于内连接,而后三个函数(leftOuterJoin,rightOuterJoin,fullOuterJoin)用于类似于SQL的左、右、全连接。针对key-value形式的RDD。例子:1)数据初始化>&g…

    2025年7月11日
    4
  • Android自定义控件之滑动解锁

    Android自定义控件之滑动解锁代码参考地址https://github.com/liuzhiyuan0932/SlideUnLock代码效果图>自定义滑动解锁的控件继承自ViewpublicclassSlideUnlockViewextendsView自定义SlideUnLockView的属性在values文件夹中定义属性

    2022年6月24日
    26
  • ModelMap的用法

    ModelMap的用法ModelMap对象主要用于传递控制方法处理数据到结果页面,也就是说我们把结果页面上需要的数据放到ModelMap对象中即可,他的作用类似于request对象的setAttribute方法的作用,用来在一个请求过程中传递处理的数据。通过以下方法向页面传递参数:addAttribute(Stringkey,Objectvalue);在页面上可以通过el变量方式$key或者bboss的一系列数…

    2022年6月24日
    49

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号