深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Linux常用下载工具推荐「建议收藏」

    Linux常用下载工具推荐「建议收藏」[color="#02368d"]Linux常用下载工具推荐转自:http://doc.zoomquiet.org/data/20060730210451/index.html

    2022年7月1日
    67
  • Java学习路线图(完整详细2021版)

    Java学习路线图(完整详细2021版)作为一个男人我感觉必须得做点什么来证明一下自己,现在我又回来了,准备把自己的节操准备补一下 另外给各位未来的Java程序员说一句,别的我不清楚,学习编程请从一而终 咱们学习编程就挺难的,有这些先驱者来带领咱们学习,咱们应该感激,而且最重要的事跟着你选定的一家一直学下去 因为每家学校的学习大纲都是不一样的,但是程序员其实都是一样的,这句话你细品!仔细的品! 我不希望你忙忙碌碌的整理那么多东西,挑肥拣瘦的,最后自己学的东西还是缺失的,要不就…

    2022年5月17日
    85
  • idea2018激活码【2021免费激活】[通俗易懂]

    (idea2018激活码)这是一篇idea技术相关文章,由全栈君为大家提供,主要知识点是关于2021JetBrains全家桶永久激活码的内容IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.html1M3Q9SD5XW-eyJsa…

    2022年3月28日
    1.2K
  • 阿里云邮箱POP3、SMTP设置教程

    阿里云邮箱POP3、SMTP设置教程

    2021年9月21日
    270
  • 初识 GTK

    初识 GTKGTK+是一种函数库是用来帮助制作图形交互界面的。整个函数库都是由C语言来编写的。GTK+函数库通常也叫做GIMP工具包。

    2025年5月24日
    1
  • response中如何设置contentType

    response中如何设置contentTypeajax开发中,常遇到下面的几种情况:1服务端需要返回一段普通文本给客户端2服务端需要返回一段HTML代码给客户端3服务端需要返回一段XML代码给客户端4服务端需要返回一段javascript代码给客户端5服务端需要返回一段json串给客户端================================对于每一种返回类型规范的做法是要在服务端…

    2022年7月19日
    134

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号