PyTorch实现ResNet18

PyTorch实现ResNet18ResNet-18结构基本结点代码实现importtorchimporttorch.nnasnnfromtorch.nnimportfunctionalasFclassRestNetBasicBlock(nn.Module):def__init__(self,in_channels,out_channels,stride):super(RestNetBasicBlock,self).__init__()self.

大家好,又见面了,我是你们的朋友全栈君。

ResNet-18结构

在这里插入图片描述

基本结点

在这里插入图片描述

代码实现

import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        output = self.conv1(x)
        out = F.relu(self.bn1(output))

        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)


class RestNet18(nn.Module):
    def __init__(self):
        super(RestNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out

用来预测CIFAR-10数据集

数据集

官网链接:CIFAR-10 DATASET
在这里插入图片描述

测试代码

import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from restnet18.restnet18 import RestNet18


# 用CIFAR-10 数据集进行实验

def main():
    batchsz = 128

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    # model = Lenet5().to(device)
    model = RestNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label: [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

运行结果

在这里插入图片描述
感觉挺low的,迭代50多次能达到80多的准确率

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141287.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 输入3个数a,b,c,要求输出最大值_二维数组求最大值及下标

    输入3个数a,b,c,要求输出最大值_二维数组求最大值及下标7-1 求最大值及其下标 (20分) 本题要求编写程序,找出给定的n个数中的最大值及其对应的最小下标(下标从0开始)。输入格式: 输入在第一行中给出一个正整数n(1<n≤10)。第二行输入n个整数,用空格分开。输出格式: 在一行中输出最大值及最大值的最小下标,中间用一个空格分开。 输入样例: 6 2 8 10 1 9 10 输出样例: 10 2#inc…

    2022年8月18日
    3
  • linux 渗透工具_适用于Linux的十大最佳渗透测试工具[通俗易懂]

    linux 渗透工具_适用于Linux的十大最佳渗透测试工具[通俗易懂]linux渗透工具ThisarticlecoverssomeofthebestpenetrationtestingtoolsforLinuxCybersecurityisabigconcernforbothsmallandbigorganizations.Inanagewheremoreandmorebusinessesaremov…

    2022年8月12日
    3
  • Qt Quick中的信号与槽

    在QML中,在QtQuick中,要想妥善地处理各种事件,肯定离不开信号与槽,本博的主要内容就是整理Qt中的信号与槽的内容。1.链接QML类型的已知信号QML中已有类型定义的信号分为两类:一类

    2021年12月29日
    42
  • 学c++还是学java就业「建议收藏」

    学c++还是学java就业「建议收藏」Java更偏向业务型开发,比如银行的xx管理系统,安卓手机的软件以及WEB等等。java更容易入手,学会用框架基本就能来开发,开发效率(完成的速度)相对高,当前相对C++更好就业,薪资平均水平相比C++略高(参考2014年谷歌统计数据)。C++,难度相对高,入手较难深入也难,它涉及的内容很多,特性很多,可以做一些考虑性能(并发,速度)的东西,比如各种后台服务,游戏的后台部分,C++主要更服务器打交道,当然你要用上MFC,QT等也能做界面的东西。前途还是钱途:当前的话,可能Java性价比更高。不过游戏,

    2022年7月17日
    11
  • html页面根据js名称调用需要的js

    html页面根据js名称调用需要的js

    2021年8月9日
    55
  • 常用组合数计算公式及推算[通俗易懂]

    常用组合数计算公式及推算[通俗易懂]参考:博客1博客2更多更详细请看博客2组合数的通项公式:公式1:证明:n个不同的数选择m个,第m个的选择方案为:1、选第m个:2、不选第m个:公式2:证明:性质3:证明:性质4:证明:性质5:…

    2022年7月25日
    23

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号