深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • tabnine 激活码_通用破解码

    tabnine 激活码_通用破解码,https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月17日
    645
  • HashMap的hash碰撞

    HashMap的hash碰撞看了看HashMap的源码,有些心得先写下,以便以后查看,不然又要忘了,但不知道对不对,希望没误人子弟吧。主要是解释下HashMap底层实现与如何解决hash碰撞的。HashMap底层是table数组,Entry是HashMap的内部类。可以看到HashMap的key与value实际是保存在Entry中的,next是下一个Entry节点。staticfinalEntry<…

    2022年6月22日
    32
  • Java零基础学习

    Java零基础学习java0基础1.注释

    2022年6月20日
    25
  • vue上传文件夹和文件_vue打包后的文件如何运行

    vue上传文件夹和文件_vue打包后的文件如何运行<template><el-form:model=”formData”label-width=”280px”><el-form-itemlabel=”上传kbase视图文件”><el-uploadclass=”upload-demo”ref=”upload”:action=”formData.url”:he.

    2022年10月10日
    2
  • kafka 集群配置_kafka集群原理

    kafka 集群配置_kafka集群原理一、kafka简述1、简介kafka是一个高吞吐的分布式消息队列系统。特点是生产者消费者模式,先进先出(FIFO)保证顺序,自己不丢数据,默认每隔7天清理数据。消息列队常见场景:系统之间解耦合、峰值压力缓冲、异步通信。2、集群介绍(1)Kafka架构是由producer(消息生产者)、consumer(消息消费者)、borker(kafka集群的server,负责处理消息读、…

    2022年8月31日
    5
  • Activiti流程引擎_activiti工作流原理

    Activiti流程引擎_activiti工作流原理Activiti框架提供的流程引擎配置类ProcessEngineConfiguration的类图如下:下面的图是流程引擎的架构图:由上图我们可以很清楚地从全局角度了解ProcessEngineConfiguration类:1)EngineServices:该接口中定义了获取各种服务类实例对象的方法。2)ProcessEngine:继承EngineServices接口,并增…

    2022年10月20日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号