Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理pytorch 进行 CIFAR 10 分类 1 CIFAR 10 数据加载和处理 1 写在前面的话这一篇博文的内容主要来自于 pytorch 的官方 tutorial 然后根据自己的理解把 cifar10 这个示例讲一遍 权当自己做笔记 因为这个 cifar10 是官方 example 所以适合我们拿来先练手 至少能保证代码的正确性 之所以第一篇 pytorch 的博文 其实之前还写了篇如何安装 pytorch

pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理
Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

1、写在前面的话

这一篇博文的内容主要来自于pytorch的官方tutorial,然后根据自己的理解把cifar10这个示例讲一遍,权当自己做笔记。因为这个cifar10是官方example,所以适合我们拿来先练手,至少能保证代码的正确性。

之所以第一篇pytorch的博文(其实之前还写了篇如何安装pytorch)就用cifar10做例子,是我个人觉得先从宏观上了解一个例子的样貌是什么样的,然后我们再来针对性的学习相关的知识点,这样可能效率快一点。 所以我不光是在讲cifar10这个例子,而是在剖析这个例子,说明这些知识点属于哪个模块,该去哪儿找后续也会写相关博客进行一些细节性的讲解。


官网相关内容的链接如下:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#loading-and-normalizing-cifar10

我的系列博文

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理(本文)

Pytorch打怪路(一)pytorch进行CIFAR-10分类(2)定义卷积神经网络

Pytorch打怪路(一)pytorch进行CIFAR-10分类(3)定义损失函数和优化器
Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练
Pytorch打怪路(一)pytorch进行CIFAR-10分类(5)测试
 

2.大致流程

一般来说,使用深度学习框架我们会经过下面几个流程:

模型定义(包括损失函数的选择) —>数据处理和加载 —> 训练(可能包含训练过程可视化) —> 测试

所以我们在自己写代码的时候也基本上就按照这四个大模块四步走就ok了

官方给的这个例子呢,是先进行的第二步数据处理和加载,然后定义网络,这其实没什么关系。

所以本篇博文讲解的是  数据处理和加载 这一步的内容,当然会接着在后续博文写其他步骤。

下面我就直接上程序,并且添加我自己的一些注解,觉得有问题的欢迎提出,希望和大家多交流。

 

 

3、代码分析

首先使用
torchvision加载和归一化我们的训练数据和测试数据。

a、torchvision这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在torchvision.datasets模块中。
b、同时,也封装了一些处理数据的方法。保存在torchvision.transforms模块中
c、还封装了一些模型和工具封装在相应模型中。可以从下图一窥大貌:

 

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

细节见torchvision的官方文档链接:http://pytorch.org/docs/0.3.0/torchvision/index.html 

 

#  首先当然肯定要导入torch和torchvision,至于第三个是用于进行数据预处理的模块
import torch
import torchvision
import torchvision.transforms as transforms

#  由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor
    #  首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )
    #  把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 

    # 定义了我们的训练集,名字就叫trainset,至于后面这一堆,其实就是一个类:
    # torchvision.datasets.CIFAR10( )也是封装好了的,就在我前面提到的torchvision.datasets
    # 模块中,不必深究,如果想深究就看我这段代码后面贴的图1,其实就是在下载数据
    #(不翻墙可能会慢一点吧)然后进行变换,可以看到transform就是我们上面定义的transform
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
    # trainloader其实是一个比较重要的东西,我们后面就是通过trainloader把数据传入网
    # 络,当然这里的trainloader其实是个变量名,可以随便取,重点是他是由后面的
    # torch.utils.data.DataLoader()定义的,这个东西来源于torch.utils.data模块,
    #  网页链接http://pytorch.org/docs/0.3.0/data.html,这个类可见我后面图2
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
    # 对于测试集的操作和训练集一样,我就不赘述了
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
    # 类别信息也是需要我们给定的
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
———- update 2019/11/28 ——–
经常有读者问上面的代码中的 Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 这个是什么意思
 
前面的(0.5,0.5,0.5) 是
R G B 三个通道上的均值, 后面(0.5, 0.5, 0.5)
是三个通道的标准差
注意通道顺序是 R G B ,用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。
这两个tuple数据是用来对RGB 图像做归一化的,如其名称 Normalize 所示
这里都取0.5只是一个近似的操作,实际上其均值和方差并不是这么多,但是就这个示例而言 影响可不计。
 精确值是通过分别计算R,G,B三个通道的数据算出来的,
比如你有2张图片,都是100*100大小的,那么两图片的像素点共有2*100*100 = 20 000 个;
 那么这两张图片的mean求法为:
mean_R: 这20000个像素点的R值加起来,除以像素点的总数,这里是20000;
mean_G 和mean_B 两个通道 的计算方法 一样的。
 
标准差求法:
首先标准差就是开了方的方差,所以其实就是求方差,方差公式就是我们数学上的那个求方差的公式:
Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理
也是3个通道分开算,
比如算R通道的, 这里X就为20000个像素点
各自的R值,再减去R均值,上面已经算好了;
然后平方;
然后20000个像素点相加,然后求平均除以20000,
得到R的方差,再开方得标准差。
注意!!!
如果你用的是自己创建的数据集,从头训练,那最好还是要自己统计自己数据集的这两个量
 
如果①你加载的的是
pytorch上的预训练模型,自己只是微调模型;
②或者你用了
常见的数据集比如VOC或者COCO之类的,但是用的是自己的网络结构,即pytorch上没有可选的预训练模型
那么可以使用一个
pytorch上给的通用的统计值
mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 参考官网链接:https://pytorch.org/docs/stable/torchvision/models.html   用ctrl+f 搜索mean 第一个位置就有说明

 

 

4、图片

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

 

图1

root表示存放dataset的位置,本例就是’ ./data’

train,如果为True,就创建的是trainning set,可以看到我们的trainset调用它时用的是True

           而testset调用它时,参数里填的是False

transform,这个transform是形参名,由于我们定义的变换也叫transform,所以就有transform = transform,

看起来可能有点怪,其实我们的之前的变换可以随便命名

download,如果为True,就从网上下载,如果已经有下载好的数据就不会重复下载了

——————————————————————————————————————————————

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

图2

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

 dataset:就是数据的来源,比如训练集就添入我们定义的trainset

batch_size:每批次进入多少数据,本例中填的是4

shuffle:如果为真,就打乱数据的顺序,本例为True

num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

本例中为2。这个值是什么意思呢,就是数据读入的速度到底有多快,你选的用来加载数据的

子进程越多,那么显然数据读的就越快,这样的话消耗CPU的资源也就越多,所以这个值在自己

跑实验的时候,可以自己试一试,既不要让花在加载数据上的时间太多,也不要占用太多电脑资源

 

 

所以这第一步—-数据加载和处理,要注意的就是这些内容,如果程序运行完毕,会显示:

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

 

这里我提个小建议,就是下载数据的那个root参数,官网代码给的是’./data’,这个其实可以改成自己的位置

而且,建议改成 绝对路径 要好一点。然后由于代码可以直接从官网复制粘帖,所以这部分程序的运行的快慢,

基本就取决于下载数据的网速了,建议翻墙,可能翻墙也不见得很快,不如晚上睡觉前开始下,说不定第二天

醒来就下好了呢…….







版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/228633.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月16日 下午6:34
下一篇 2026年3月16日 下午6:34


相关推荐

  • 最全的AI插件Astute Graphics 2020全系列[通俗易懂]

    最全的AI插件Astute Graphics 2020全系列[通俗易懂]AstuteGraphics2020全系列AI插件Mac版包含了AstuteGraphics出品的全部AI插件,包含18个常用辅助功能,可以帮助用户提高平面和矢量设计的效率,不断提高你的设计工作流程。让图像处理工作更快速高效。完美兼容AdobeIllustrator2018–2020,有需要的用户不要错过哦!AstuteGraphics全系列ai插件安装教程安装AstuteGraphics全系列ai插件之前请先安装AdobeIllustratorcc2020,在AI中

    2022年5月7日
    1.1K
  • C语言switch史上最详细的讲解

    C语言switch史上最详细的讲解原文链接 https github com shellhub blog issues 41C 语言 switch 史上最详细的讲解 switch 语句允许测试变量与值列表的相等性 每个值称之为案例或者 case 程序会检查 switch 后面的值并且与 case 后面的值比对 如果相等则执行后面的代码或代码块语法 switch 在 C 语言中的语法如下 switch expression cas

    2026年3月26日
    2
  • 讯飞星火智能平台使用教程

    讯飞星火智能平台使用教程

    2026年3月14日
    4
  • 蓝桥杯单片机—-NE555频率测量

    #include<stc15f2k60s2.h>#defineucharunsignedchar//定义无符号字符类型uchar#defineuintunsignedint//定义无符号整型类型uintucharcodetab[]={0xc0,0xf9,0xa4,0xb0,0x99,0x92,0x82,0xf8,0x80,0x90,0xbf,0xff,0x8e};//数字0~9,“-”,“关”,“F”ucharyi,er,san,si,wu,…

    2022年4月12日
    50
  • 欧拉函数最全总结

    欧拉函数最全总结文章目录欧拉函数的内容一、欧拉函数的引入二、欧拉函数的定义三、欧拉函数的性质四、欧拉函数的计算方法(一)素数分解法(二)编程思维1.求n以内的所有素数2.求φ(n)3.格式化输出0-100欧拉函数表(“x?”代表十位数,“x”代表个位数)五、欧拉函数相关定理以及证明(一)定理1:缩系与欧拉函数的关系(二)定理2:缩系的充要条件(三)定理3:缩系拓展1.简单证明:(a,m)=1,(x,m)=1,故(ax,m)=1。(四)定理4:设m>1,(a,m)=1,则aφ(m)≡1(modm).1.**若ac≡bc

    2022年8月22日
    11
  • 关于@(posedge clk)和@(itf.cb)的区别

    关于@(posedge clk)和@(itf.cb)的区别一 采样 region 区别 posedgeclk 采样是在 activeregion 相当于 observeregio 这里会采样最新的值 itf cb 采样会在 preponeregio 二 具体示例代码在这里插入代码片

    2026年3月16日
    1

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号