深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • android开发环境搭建——android studio

    android开发环境搭建——android studio文章目录一、安装jdk二、下载包含androidsdk的androidstudio三、安装AndroidStudio四、配置AndroidStudio五、第一个helloworld六、Androidsdk环境配置七、配置androidmanager镜像八、取消安装的时候设置的代理九、安装模拟器十、运行安卓程序1、下载jdk、androidsdk、eclipse、adt2、安装…

    2022年7月23日
    9
  • pycharm 2021.2.3激活码(已测有效)

    pycharm 2021.2.3激活码(已测有效),https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月14日
    1.4K
  • linux端口转发技术(单端口分发)

    端口转发映射的程序叫rinetd,下载地址,直接manke编译安装即可。12345678910111213141516[root@PortForward02 src]# wget http://www.boutell.com/r

    2022年4月18日
    34
  • 史上最全的Android面试题集锦

    史上最全的Android面试题集锦Android基本知识点1、常规知识点1、Android类加载器在Android开发中,不管是插件化还是组件化,都是基于Android系统的类加载器ClassLoader来设计的。只不过Android平台上虚拟机运行的是Dex字节码,一种对class文件优化的产物,传统Class文件是一个Java源码文件会生成一个.class文件,而Android是把所有Class文件进行合并、优化,然后…

    2022年5月11日
    41
  • Camstar CDO增加自定义字段

    Camstar CDO增加自定义字段本节讲述如何在Camstar原生CDO里加入自定义字段进入Designer,打开CDO页,找到要增加字段的CDO,打开,切换到Fields页,点击下面的Add按钮。在弹出的窗口中,输入相应的数据:DataType增加的字段的数据类型,字符串、整数、浮点数、Object等FieldType字段类型,描述字段的具体用处,不同类型的数据字段长度是不同的(比如字符串的长度)Name字段名称,也是数据库表里的默认字段名称Caption字段描述,也是在Modeling配置页面里对应字段的名称点

    2025年7月1日
    4
  • 约4万个外国人名,中英对照[通俗易懂]

    约4万个外国人名,中英对照[通俗易懂]以下是一些外国人名,中英对照

    2022年9月30日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号