深度学习 pytorch cifar10数据集训练「建议收藏」

深度学习 pytorch cifar10数据集训练「建议收藏」1.加载数据集,并对数据集进行增强,类型转换官网cifar10数据集附链接:https://www.cs.toronto.edu/~kriz/cifar.html读取数据过程中,可以改变batch_size和num_workers来加快训练速度transform=transforms.Compose([#图像增强transforms.Resize(120),transforms.RandomHorizontalFlip(),

大家好,又见面了,我是你们的朋友全栈君。

1.加载数据集,并对数据集进行增强,类型转换
官网cifar10数据集
附链接:https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述
读取数据过程中,可以改变batch_size和num_workers来加快训练速度


    transform=transforms.Compose([
        #图像增强
        transforms.Resize(120),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(96),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
        #转变为tensor 正则化
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
    ])

    trainset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=True,
        download=True,
        transform=transform
    )

    trainloader=data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True, #乱序
        num_workers=4,
)

    testset=tv.datasets.CIFAR10(
        root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
        train=False,
        download=True,
        transform=transform
    )

    testloader=data.DataLoader(
        testset,
        batch_size=2,
        shuffle=False,
        num_workers=2
    )

net网络:

   class Net(nn.Module):
        def  __init__(self):
            super(Net, self).__init__()
            self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
            self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
            self.max=nn.MaxPool2d(2,2)
            self.q1=nn.Linear(16*441,120)
            self.q2=nn.Linear(120,84)
            self.q3=nn.Linear(84,10)
            self.relu=nn.ReLU()
        def forward(self,x):
            x1=self.max(F.relu(self.conv1(x)))
            x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
            x3=x2.view(x2.size()[0],-1)
            x4=F.relu(self.q1(x3))
            x5=F.relu(self.q2(x4))
            x6=self.q3(x5)
            return x6

训练模型

    net=Net()
    #损失函数
    loss=nn.CrossEntropyLoss()
    opt=optim.SGD(net.parameters(),lr=0.001)

    for epoch in range(5):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            inputs,labels=data
            inputs=inputs.cuda()
            labels=labels.cuda()

            inputs,labels=Variable(inputs),Variable(labels)

            opt.zero_grad()
            net.to(torch.device('cuda:0'))
            h=net(inputs)
            cost=loss(h,labels)
            cost.backward()
            opt.step()

            running_loss+=cost.item()

            if i%2000==1999:
                print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))

                running_loss=0.0

                torch.save(net.state_dict(),r'net.pth')


                correct=0
                total=0
                for data in testloader:
                    images,labels=data
                    optputs=net(Variable(images.cuda()))
                    _,predicted=torch.max(optputs.cpu(),1)
                    total+=labels.size(0)
                    correct+=(predicted==labels).sum()


                print("准确率: %d %%" %(100*correct/total))

接下来可以直接进行训练
在这里插入图片描述
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152102.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月22日 上午10:46
下一篇 2022年6月22日 上午11:00


相关推荐

  • 振动分析软件有哪些_excel控件按钮怎么控制图表

    振动分析软件有哪些_excel控件按钮怎么控制图表LightningChart是优化了GPU加速,硬件性能的制图组件,用于实时呈现超过10亿个数据点的海量数据。同时LightningChart是为了处理实时数据采集和处理而开发的,可有效利用CPU和内存资源。LightningChart包括广泛的2D,高级3D,Polar,Smith,3D饼/甜甜圈,地理地图和GIS图表以及适用于科学,工程,医学,航空,贸易,能源和其他领域的体绘制功能。当您想到振动分析时,您会想到什么?它正在成为结构工程中一种非常常见的识别方法,用于识别潜在的结构完整性问题,例如隐藏的物

    2022年10月15日
    3
  • 平定igb之“乱”

    平定igb之“乱”平定 igb 之 乱 作者 dnk admin

    2026年3月18日
    2
  • 容斥原理详解

    容斥原理详解转载于 https blog csdn net xianglunxi article details 翻译 vici cust 对容斥原理的描述容斥原理是一种重要的组合数学方法 可以让你求解任意大小的集合 或者计算复合事件的概率 描述 nbsp nbsp nbsp nbsp nbsp nbsp nbsp 容斥原理可以描述如下 nbsp nbsp nbsp nbsp nbsp nbsp nbsp nbsp nbsp 要计算几个集合并集的大小 我们要先将所有单个集合的大小计算出来 然后减

    2026年3月20日
    2
  • 测试用例-单元测试

    测试用例-单元测试单元测试 编写手册 1 简述本文主要针对如何使用 Junit 编写单元测试进行描述文中的实例基于 Junit4 所谓单元测试 即是指针对程序中的一些单元进行测试的方法这些单元在 Junit 中的最小单位为方法借助单元测试 我们可以轻松地单独测试程序中的某一个逻辑片段而不需要在意程序的外部依赖和其它逻辑接口测试单元测试只能以接口为维度进行测试只需被测试的单元逻辑正常即可工程必须编译通过并打包进行部署可以不依赖外部 测试进度不再受制于外部条件工程的外部依赖 数据库 调用

    2025年8月19日
    4
  • C语言回文字符串

    C语言回文字符串“回文串”是一个正读和反读都一样的字符串,字符串由数字和小写字母组成,比如“level”或者“abcdcba”等等就是回文串。请写一个程序判断读入的字符串是否是“回文”。输入:包含多个测试实例,每一行对应一个字符串,串长最多100字母。输出:对每个字符串,输出它是第几个,如第一个输出为”case1:”;如果一个字符串是回文串,则输出”yes”,否则输出”no”,在yes/no之前用一个空格…

    2022年6月6日
    36
  • pycharm与python的区别_python与pycharm有何区别

    pycharm与python的区别_python与pycharm有何区别Python 是一种计算机程序设计语言 是一种面向对象的动态类型语言 最初被设计用于编写自动化脚本 shell 随着版本的不断更新和语言新功能的添加 越来越多被用于独立的 大型项目的开发 PyCharm 是 Python 的专用 IDE 地位类似于 Java 的 IDEEclipse 功能齐全的集成开发环境同时提供收费版和免费版 即专业版和社区版 PyCharm 是安装最快的 IDE 且安装后的

    2026年3月27日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号