ResNet18复现「建议收藏」

ResNet18复现「建议收藏」ResNet18的网络架构图首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。代码如下:importtorchfromtorchimportnnfromtorch.nnimportfunctionalasFclassBasicBlock(nn.Module):def__ini

大家好,又见面了,我是你们的朋友全栈君。

ResNet18的网络架构图

ResNet18复现「建议收藏」

首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。

代码如下:

import torch
from torch import nn
from torch.nn import functional as F

class BasicBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicBlock,self).__init__()
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size,stride,padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        return F.relu(x+output)
    

class BasicDownBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride):
        super(BasicDownBlock,self).__init__()     
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size[0],stride[0],padding=1)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size[0],stride[1],padding=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        self.conv3=nn.Conv2d(in_channels,out_channels,kernel_size[1],stride[0])
        self.bn3=nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        output=self.bn1(self.conv1(x))
        output=self.bn2(self.conv2(output))
        output1=self.bn3(self.conv3(x))
        return F.relu(output1+output)

class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        # 3x224x224-->64x112x112
        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3)
        self.bn1=nn.BatchNorm2d(64)
        # 64x112x112-->64x56x56
        self.pool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        # 64x56x56-->64x56x56
        self.layer1=nn.Sequential(
            BasicBlock(64,64,3,1),
            BasicBlock(64,64,3,1)
        )
        # 64x56x56-->128*28*28
        self.layer2=nn.Sequential(
            BasicDownBlock(64,128,[3,1],[2,1]),
            BasicBlock(128,128,3,1)
        )
        # 128*28*28-->256*14*14
        self.layer3=nn.Sequential(
            BasicDownBlock(128,256,[3,1],[2,1]),
            BasicBlock(256,256,3,1)
        )
        # 256*14*14-->512x7x7
        self.layer4=nn.Sequential(
            BasicDownBlock(256,512,[7,1],[2,1]),
            BasicBlock(512,512,3,1)
        )
        # 512x7x7-->512x1x1
        self.avgpool=nn.AdaptiveMaxPool2d(output_size=(1,1))
        self.flat=nn.Flatten()
        self.linear=nn.Linear(512,10)
        
    def forward(self,x):
        output=self.pool1(F.relu(self.bn1(self.conv1(x))))
        output=self.layer1(output)
        output=self.layer2(output)
        output=self.layer3(output)
        output=self.layer4(output)
        output=self.avgpool(output)
        output=self.flat(output)
        output=self.linear(output)
        return output
    

net=ResNet18()
x=torch.randn(32,3,224,224)
print(x.shape)
y=net(x)
print(y.shape)

代码中BasicBlock为普通的残差块,注意步长和卷积核的大小,BasicDownBlock为下采样的残差块,然后将四层的网络表示出来,最后进行验证x.shape为torch.Size([32, 3, 224, 224]),y.shape为torch.Size([32, 10])。 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141545.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • [职场]最近聊到30岁以上的程序员,该何去何从了?你有啥想法?

    [职场]最近聊到30岁以上的程序员,该何去何从了?你有啥想法?

    2022年2月19日
    41
  • 什么是多模态学习?

    什么是多模态学习?首先,什么叫做模态(Modality)呢?每一种信息的来源或者形式,都可以称为一种模态。例如,人有触觉,听觉,视觉,嗅觉;信息的媒介,有语音、视频、文字等;多种多样的传感器,如雷达、红外、加速度计等。以上的每一种都可以称为一种模态。同时,模态也可以有非常广泛的定义,比如我们可以把两种不同的语言当做是两种模态,甚至在两种不同情况下采集到的数据集,亦可认为是两种模态。因此,多模态机器学习,…

    2022年6月16日
    32
  • Microsoft Office 2007 中文专业版密钥

    Microsoft Office 2007 中文专业版密钥MicrosoftOffice2007中文专业版(微软原版)正版密钥MicrosoftOfficeVisio2007简体中文专业版:简介:    便于IT和商务专业人员就复杂信息、系统和流程进行可视化处理、分析和交流。使用具有专业外观的OfficeVisio2007图表,可以促进对系统和流程的了解,深入了解复杂信息并利用这些知识做出更好的业务决策。迅雷下载    …

    2022年7月19日
    18
  • android之存储篇_SQLite存储方式「建议收藏」

    SQLite是一种转为嵌入式设备设计的轻型数据库,其只有五种数据类型,分别是:    NULL: 空值    INTEGER: 整数    REAL: 浮点数    TEXT: 字符串    BLOB: 大数据  在SQLite中,并没有专门设计BOOLEAN和DATE类型,因为BOOLEAN型可以用INTEGER的0和1代替true和false,而DATE类型则可以拥有特

    2022年3月10日
    35
  • MYSQL8.0以上版本正确修改ROOT密码[通俗易懂]

    MYSQL8.0以上版本正确修改ROOT密码[通俗易懂]部署环境:安装版本redhatCent7.0MYSQL版本8.0.2.0成功部署完毕后出现故障情况:1.正常启动MYSQL服务后,敲Linux中root账户和密码进入不去。2.从/etc/my.cnf配置文件中加入skip-grant-table后正常登陆,但是不能创建用户等多操作总结来说:想进去mysql后不能操作多指令,操作多指令又不能进去mysql,死…

    2022年5月6日
    93
  • pycharm virtualenv和conda_pycharm文件名红色

    pycharm virtualenv和conda_pycharm文件名红色from: http://www.cnblogs.com/IDRI/p/6354237.htmlLinux:启动虚拟环境:sourceenv/bin/activate Windows:pipinstallvirtualenv创建虚拟环境目录env激活虚拟环境:C:\Python27\Scripts

    2022年8月28日
    0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号