深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月22日 上午10:46
下一篇 2022年6月22日 上午11:00


相关推荐

  • 微服务链路追踪有哪些_微服务网关原理

    微服务链路追踪有哪些_微服务网关原理目录Sleuth简介相关术语使用Sleuth引入依赖创建服务product-serviceorder-service启动&测试Zipkin使用Zipkin参考文章Sleuth简介Sleuth是SpringCloud的组件之一,它为SpringCloud实现了一种分布式追踪解决方案,兼容Zipkin,HTrace和其他基于日志的追踪…

    2025年8月6日
    5
  • HBase Shell命令大全「建议收藏」

    HBase Shell命令大全「建议收藏」HBase关键名称:RowKey列族columnfamily单元Cell时间戳timestampHBaseShell是官方提供的一组命令,用于操作HBase。如果配置了HBase的环境变量了,就可以知己在命令行中输入hbaseshell命令进入命令行。hbaseshellhelp命令可以通过help’命名名称’来查看命令行的具体使用查询服务器状态st…

    2022年7月16日
    18
  • MySQL详解--锁

    MySQL详解--锁锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中,除传统的计算资源(如CPU、RAM、I/O等)的争用以外,数据也是一种供许多用户共享的资源。如何保证数据并发访问的一致性、有效性是所有数据库必须解决的一个问题,锁冲突也是影响数据库并发访问性能的一个重要因素。从这个角度来说,锁对数据库而言显得尤其重要,也更加复杂。本章我们着重讨论MySQL锁机制的特点,常见的锁问题,以及解决MySQL

    2022年6月6日
    38
  • it领域的摩尔定律_裴蜀定理

    it领域的摩尔定律_裴蜀定理每十八个月,计算机等IT产品的性能会翻一番;或者说相同性能的计算机等IT产品,每十八个月价钱会降一半。

    2022年8月6日
    7
  • tcp 粘包[通俗易懂]

    tcp 粘包[通俗易懂]tcp粘包

    2022年8月11日
    7
  • 万能激活成功教程器修改器_闪照激活成功教程软件

    万能激活成功教程器修改器_闪照激活成功教程软件第一步:下载补丁文件如果是2017.2以上版本的,需要JetbrainsCrack-2.6.6及以上版本如果是2018.1及以上版本的,需要JetbrainsCrack-2.8及以上版本本人是windows64G系统,安装的2018.1.4专业版,试过JetbrainsCrack-2.6的,只能延长有效期一年;使用JetbrainsCrack-2.8的版本,有效期到2099年12月31…

    2025年7月7日
    7

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号