pytorch中resnet_通过Pytorch实现ResNet18

pytorch中resnet_通过Pytorch实现ResNet18对于像我这样刚刚入门深度学习的同学来说,可能接触学习了一个开发工具,却没有通过运用来熟练的掌握它。而ResNet是深度学习里面一个非常重要的backbone,并且ResNet18实现起来又足够简单,所以非常适合拿来练手。我们这里的开发环境是:python3.6.10pytorch1.5.0torchvision0.6.0cudatoolkit10.2.89cudnn7.6.5首先,我们需…

大家好,又见面了,我是你们的朋友全栈君。

对于像我这样刚刚入门深度学习的同学来说,可能接触学习了一个开发工具,却没有通过运用来熟练的掌握它。而ResNet是深度学习里面一个非常重要的backbone,并且ResNet18实现起来又足够简单,所以非常适合拿来练手。

我们这里的开发环境是:

python 3.6.10

pytorch 1.5.0

torchvision 0.6.0

cudatoolkit 10.2.89

cudnn 7.6.5

首先,我们需要明确ResNet18的网络结构。在我自己学习的一开始,我对于ResNet的ShortCut机制的实现不是很清楚,当你知道怎么实现这个机制之后,那么剩下的部分也就没有什么挑战了。

论文中,ResNet各种层数的结构如下:pytorch中resnet_通过Pytorch实现ResNet18

我们观察,实际可以将ResNet18分成6个部分:

1. Conv1:也就是第一层卷积,没有shortcut机制。

2. Conv2:第一个残差块,一共有2个。

3. Conv3:第二个残差块,一共有2个。

4. Conv4:第三个残差块,一共有2个。

5. Conv5:第四个残差块,一共有2个。

6. fc:全连阶层。pytorch中resnet_通过Pytorch实现ResNet18

明确这些部分之后,我们就可以开始着手实现啦!

首先,咱们实现残差块:

import torch

import torch.nn as nn

import torch.nn.functionl as F

#定义残差块ResBlock

class ResBlock(nn.Module):

def __init__(self, inchannel, outchannel, stride=1):

super(ResBlock, self).__init__()

#这里定义了残差块内连续的2个卷积层

self.left = nn.Sequential(

nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),

nn.BatchNorm2d(outchannel),

nn.ReLU(inplace=True),

nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),

nn.BatchNorm2d(outchannel)

)

self.shortcut = nn.Sequential()

if stride != 1 or inchannel != outchannel:

#shortcut,这里为了跟2个卷积层的结果结构一致,要做处理

self.shortcut = nn.Sequential(

nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),

nn.BatchNorm2d(outchannel)

)

def forward(self, x):

out = self.left(x)

#将2个卷积层的输出跟处理过的x相加,实现ResNet的基本结构

out = out + self.shortcut(x)

out = F.relu(out)

return out

接着,我们实现ResNet18模型:

class ResNet(nn.Module):

def __init__(self, ResBlock, num_classes=10):

super(ResNet, self).__init__()

self.inchannel = 64

self.conv1 = nn.Sequential(

nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),

nn.BatchNorm2d(64),

nn.ReLU()

)

self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)

self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)

self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)

self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)

self.fc = nn.Linear(512, num_classes)

#这个函数主要是用来,重复同一个残差块

def make_layer(self, block, channels, num_blocks, stride):

strides = [stride] + [1] * (num_blocks – 1)

layers = []

for stride in strides:

layers.append(block(self.inchannel, channels, stride))

self.inchannel = channels

return nn.Sequential(*layers)

def forward(self, x):

#在这里,整个ResNet18的结构就很清晰了

out = self.conv1(x)

out = self.layer1(out)

out = self.layer2(out)

out = self.layer3(out)

out = self.layer4(out)

out = F.avg_pool2d(out, 4)

out = out.view(out.size(0), -1)

out = self.fc(out)

return out

到此,一个ResNet18网络就搭建完成了,不过,仅仅是搭建完成还是远远不够的,让我们拿它来练练手。笔者在jupyter notebook上使用CIFAR10数据集来测试我们的ResNet18模。

from resnet import ResNet18

#Use the ResNet18 on Cifar-10

import torch.optim as optim

import torchvision

import torchvision.transforms as transforms

#check gpu

device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)

#set hyperparameter

EPOCH = 10

pre_epoch = 0

BATCH_SIZE = 128

LR = 0.01

#prepare dataset and preprocessing

transform_train = transforms.Compose([

transforms.RandomCrop(32, padding=4),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

])

transform_test = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

])

trainset = torchvision.datasets.CIFAR10(root=’../data’, train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=’../data’, train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

#labels in CIFAR10

classes = (‘plane’, ‘car’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’)

#define ResNet18

net = ResNet18().to(device)

#define loss funtion & optimizer

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

然后开始跑模型:

#train

for epoch in range(pre_epoch, EPOCH):

print(‘\nEpoch:%d’ % (epoch + 1))

net.train()

sum_loss = 0.0

correct = 0.0

total = 0.0

for i, data in enumerate(trainloader, 0):

#prepare dataset

length = len(trainloader)

inputs, labels = data

inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()

#forward & backward

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

#print ac & loss in each batch

sum_loss += loss.item()

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += predicted.eq(labels.data).cpu().sum()

print(‘[epoch:%d, iter:%d] Loss:%.03f| Acc:%.3f%%’

% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))

#get the ac with testdataset in each epoch

print(‘Waiting Test…’)

with torch.no_grad():

correct = 0

total = 0

for data in testloader:

net.eval()

images, labels = data

images, labels = images.to(device), labels.to(device)

outputs = net(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum()

print(‘Test\’s ac is:%.3f%%’ % (100 * correct / total))

print(‘Train has finished, total epoch is%d’ % EPOCH)

如果不出意外,这个模型就已经跑起来了,到这里,咱们就已经完成的实现了一个ResNet18网络,这个模型的jupyter notebook源码我已经放到了github上,如果这片文章对你有帮助,那就给我star一下吧:samcw/ResNet18-Pytorch​github.compytorch中resnet_通过Pytorch实现ResNet18

参考:

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141371.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • nginx指令详解_考试说明全解

    nginx指令详解_考试说明全解常见的命令有:nginx-sreopen#重启Nginxnginx-sreload#重新加载Nginx配置文件,然后以优雅的方式重启Nginxnginx-sstop#强制停止Nginx服务nginx-squit#优雅地停止Nginx服务(即处理完所有请求后再停止服务)nginx-t#检测配置文件是否有语法错误,然后退出nginx-?,-h#打开帮助信息nginx-v#显示版本信息并退出nginx-V#显示版本和配置选项信息,然后退出

    2025年5月25日
    3
  • SpringBoot之SpringApplication初始化

    SpringBoot之SpringApplication初始化SpringApplication的初始化之前已经分析了引导类上的@SpringBootApplication注解,接下来继续分析main方法,只调用了一句SpringApplication.run(SpringbootApplication.class,args),就启动了web容器,我们看看run方法里面做了什么publicstaticConfigurableApplicationContextrun(Class<?>[]primarySources,String[]ar

    2025年8月26日
    8
  • 光流法原理概述「建议收藏」

    光流法原理概述「建议收藏」光流的概念是Gibson在1950年首先提出来的。它是空间运动物体在观察成像平面上的像素运动的瞬时速度,是利用图像序列中像素在时间域上的变化以及相邻帧之间的相关性来找到上一帧跟当前帧之间存在的对应关系,从而计算出相邻帧之间物体的运动信息的一种方法。一般而言,光流是由于场景中前景目标本身的移动、相机的运动,或者两者的共同运动所产生的。    简单来说,光流是空间运动物体在观测成像平面上

    2022年7月23日
    17
  • MySQL常见约束条件「建议收藏」

    MySQL常见约束条件「建议收藏」约束条件:限制表中的数据,保证添加到数据表中的数据准确和可靠性!凡是不符合约束的数据,插入时就会失败!约束条件在创建表时可以使用,也可以修改表的时候添加约束条件1、约束条件分类:1)notnull:非空约束,保证字段的值不能为空s_nameVARCHAR(10)NOTNULL,#非空2)default:默认约束,保证字段总会有值,即使没有插入值,都会有默认值!…

    2022年10月13日
    3
  • 电阻和电容的识别_电容电阻怎么区分

    电阻和电容的识别_电容电阻怎么区分一、贴片电阻阻值的读法贴片电阻的阻值通常以数字形式直接标注在电阻的表面,所以读电阻的阻值直接看电阻表面的数字即可。通常情况下有三种表示方法:(1)、由三个数字组成,表明电阻的误差是±5%。前面两位是有效数字,第三位数字表示乘零的倍数,即10的几次方,基本单位是Ω。例如:103,1和0是有效数字直接写下来即可,3表示乘零倍率,也就是10的2次方,所以103表示的阻值就是1010^3=1010…

    2022年8月21日
    7
  • Node脚手架编写初学者教程

    Node脚手架编写初学者教程

    2022年3月4日
    30

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号