深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 屏幕分辨率dpi计算_hypodensity

    屏幕分辨率dpi计算_hypodensityiphone7宽2.3密集度是326一英寸,我这里有一个400px*400px的正方形由于一英寸=326,不够放,所以要用2英寸放假设我用的是400dpi*400dpi那就是占用的空间大

    2022年8月2日
    5
  • PL/SQL 学习-NVL函数[通俗易懂]

    PL/SQL 学习-NVL函数[通俗易懂]Oracle :NvlNVL函数:NVL函数是将NULL值的字段转换成默认字段输出。NVL(expr1,expr2)expr1,需要转换的字段名或者表达式。expr2,null的替代值下面是NUMBER,DATE,CHARORVARCHAR2的例子:NVL(commission_pct,0)NVL(hire_date,’01-JAN-97′)N

    2022年7月15日
    17
  • 无线桥接与中继的区别[通俗易懂]

    无线桥接与中继的区别[通俗易懂]无线桥接与中继的区别    无线桥接也就是WDS(WirelessDistributionSystem,无线分布式系统),其可以无线网络相互连接的方式构成的一个整体无线网络。WDS可把有线网络的资料,透过无线网络当中继架构来传送,借此可将网络资料传送到另外一个无线网络环境,或者是另外一个有线网络。   &nb…

    2022年4月19日
    197
  • Delphi 2007安装问题[通俗易懂]

    Delphi 2007安装问题[通俗易懂]
    安装前提是你已经下载了Delphi2007forWin32的ISO。
      Delphi2007安装程序根据不同的序列号(许可文件)来判断安装版本,一般ISO中自带的许可文件是专业版的。
      企业版和专业版的许可文件下载:delphi2007_slip.zip
      C++Builder2007的企业版许可文件(slipfileforC++Builder2007):cb2007_ent.zip
      新装方法:
      1、下载D2

    2025年6月7日
    0
  • 黑盒测试基础[通俗易懂]

    黑盒测试基础[通俗易懂]黑盒测试方法:黑盒测试也称为功能测试和数据驱动测试。它将被测软件视为一个无法打开的黑盒,主要根据功能需求设计测试用例和测试。把产品软件想象成一个只有出口和入口的黑盒。在测试过程中,你只需要知道向黑盒输入什么,知道黑盒会产生什么结果。黑盒测试方法主要有等价类划分、边界值分析、因果图、错误推测等,主要用于软件验证测试。“黑盒”法侧重于程序的外部结构,不考虑内部逻辑结构,针对测试软件界面和软件功能。“黑盒”方法是详尽的输入测试,只有当所有可能的输入都用作测试条件时,才能以这种方式检测程序中的所有错误。

    2022年10月20日
    0
  • struts2拦截器不起作用「建议收藏」

    struts2拦截器不起作用「建议收藏」为什么拦截器不起作用

    2022年10月6日
    0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号