resnet源码pytorch_pytorch conv1d

resnet源码pytorch_pytorch conv1d#Pytorch 0.4.0 ResNet34实现cifar10分类.#@Time:2018/6/17#@Author:xfLiimporttorchvisionastvimporttorchastimporttorchvision.transformsastransformsfromtorchimportnnfromtorch.utils.da…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用


# Pytorch 0.4.0 ResNet34实现cifar10分类.
# @Time: 2018/6/17
# @Author: xfLi

import torchvision as tv
import torch as t
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
t.set_num_threads(8)


class ResidualBloak(nn.Module):
    #残差块
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBloak, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel))
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

class ResNet34(nn.Module):
    #  实现主module:ResNet34  
    #  ResNet34 包含多个layer,每个layer又包含多个residual block  
    #  用子module来实现residual block,用_make_layer函数来实现layer 
    def __init__(self, num_classes):
        super(ResNet34, self).__init__()
        #前几层图像转换
        self.pre = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))

        # 重复的layer,分别有3,4,6,3个residual block
        self.layer1 = self._make_layer(16, 16, 3, stride=1)
        self.layer2 = self._make_layer(16, 32, 4, stride=1)
        self.layer3 = self._make_layer(32, 64, 6, stride=1)
        self.layer4 = self._make_layer(64, 64, 3, stride=1)
        #分类用的全连接
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        #构建layer,包含多个residual block
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel))
        layer = []
        layer.append(ResidualBloak(inchannel, outchannel, stride, shortcut))
        for i in range(1, block_num):
            layer.append(ResidualBloak(outchannel, outchannel))
        return nn.Sequential(*layer)

    def forward(self, x):
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def getData(): # 定义对数据的预处理  
    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor()])
    #训练集
    trainset = tv.datasets.CIFAR10(root='/data/', train=True, transform=transform, download=True)
    trainset_loader = DataLoader(trainset, batch_size=4, shuffle=True)
    #测试集
    testset = tv.datasets.CIFAR10(root='/data/', train=False, transform=transform, download=True)
    testset_loader = DataLoader(testset, batch_size=4, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainset_loader, testset_loader, classes

def train(): #训练
    trainset_loader, testset_loader, _ = getData() #获取数据
    net = ResNet34(10)
    print(net)
    criterion = nn.CrossEntropyLoss()
    optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #优化器

    for epoch in range(1):
        for step, (inputs,labels) in enumerate(trainset_loader):
            optimizer.zero_grad() #梯度清零
            output = net(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            if step % 10 ==9:
                acc = test(net, testset_loader)
                print('Epoch', epoch, '|step ', step, 'loss: %.4f' %loss.item(), 'test accuracy:%.4f' %acc)
    print('Finished Training')
    return net

def test(net, testdata): #测试集
    correct, total = .0, .0
    for inputs, label in testdata:
        net.eval()
        output = net(inputs)
        _, predicted = t.max(output, 1) #分类结果
        total += label.size(0)
        correct += (predicted == label).sum()
    return float(correct) / total

if __name__ == '__main__':
    net = train()








版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185357.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • setCapture和releaseCapture的小应用「建议收藏」

    setCapture和releaseCapture的小应用「建议收藏」       web开发和windows开发最大的区别就是windows开发是有状态的,而web开发是无状态的,在windows中,一切操作都可以由程序来控制,除非强制执行ctrl+alt+del;但web操作就不一样了,即使执行很重要的操作,用户一点击浏览器关闭按钮,就将前面操作成果化为乌有.尽管可以在onunload事件中加些代码,让用户可以选择是否退出,但不能从根本上解决问题!    

    2022年5月3日
    42
  • github 京东自动签到_手机京东签到在哪里

    github 京东自动签到_手机京东签到在哪里京东自动签到(利用github实现)+Cooki失效解决办法京东自动签到https://ruicky.me/2020/06/05/jd-sign/参考上面这篇文章,就不转载过来了,原文已经写的很详细了。但自己实践时Sevrer酱提示Cookie失效,同时也看到此文下面有很多跟我一样情况的,所以有提示Cookie失效的请用下面链接的方法获取Cookie,记得复制出来的Cookie值要把所有空格删除。获取京东Cookiehttps://www.plus888.com/21061.html…

    2025年11月29日
    5
  • 推荐几个长期有效的免费服务器和免费vps游戏服务器亲测再用

    推荐几个长期有效的免费服务器和免费vps游戏服务器亲测再用对于新手,搞网络购买现有的服务器和VPS成本太高!这里我长期测试筛选了几个长期有效好用的服务器!(当然免费虽好请勿滥用)1.FREEWHA这个服务器已经存在10几年了,好用谷歌收录也可以。提供1.5g空间免费流量,SSL申请免费空间很少有提供SSL的!这个空间的缺点就是无法上传大文件,最好用FTP上传。如果你做博客网站,选择SQLIVE数据库的最好。比如;ZBLOG的网站程序.2.Freehosting一个非常稳定的免费空间,提供10G空间和无限流量。这个空间其他功能都需要购买,数据库链接有流量限制

    2022年10月5日
    4
  • 0xc0000225无法进系统_win7系统出现0xc0000225无法进入系统的解决方法「建议收藏」

    0xc0000225无法进系统_win7系统出现0xc0000225无法进入系统的解决方法「建议收藏」无论谁在使用电脑的时候都可能会发现出现0xc0000225无法进入系统的问题,出现0xc0000225无法进入系统让用户们很苦恼,这是怎么回事呢,出现0xc0000225无法进入系统有什么简便的处理方式呢,其实只要依照 第一步、重启计算机,开机长按F8进入安全模式; 第二步、点击开始,打开运行菜单项,运行cmd命令;很容易就能搞定了,下面就给大家讲解一下出现0xc0000225无法进入系统的快速处…

    2022年6月26日
    71
  • git的pull和fetch区别_git pull和git clone

    git的pull和fetch区别_git pull和git clonegitfetch和gitpull都可以将远端仓库更新至本地那么他们之间有什么区别呢?想要弄清楚这个问题有有几个概念不得不提。FETCH_HEAD:是一个版本链接,记录在本地的一个文件中,指向着目前已经从远程仓库取下来的分支的末端版本。commit-id:在每次本地工作完成后,都会做一个gitcommit操作来保存当前工作到本地的repo,此时会产生一个commit-id,这是一个能

    2022年8月22日
    6
  • Windows内核编程(二)-第一个内核程序

    Windows内核编程(二)-第一个内核程序第一个内核程序通过VisualStudio新建工程注意事项:大部分widnows驱动程序都是内核驱动(KernelDriver),所以本笔记不分”驱动程序”与”内核编程”,也不区分”内核模块”(KernelModule)、“驱动程序”(Driver)与”内核程序”,这些词汇统一指编译出的扩展名为”.sys”的可执行文件(并非强制扩展名为.sys),也不区分”应用层”与”用户态”。驱动分类:NT驱动最简单的驱动模型,不支持硬件特性WDM驱动在NT驱动的基础上引入的一套驱动模型,支持即

    2022年10月8日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号