cifar10数据集pytorch

cifar10数据集pytorchcifar10 数据集 pytorch

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档



提示:以下是本篇文章正文内容,下面案例可供参考

cifar10数据

导入库

import torch import torchvision import torchvision.transforms as transforms import ssl from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F #交叉熵 import torch.optim as optim import matplotlib.pyplot as plt #图像绘制 import numpy as np import time #时间 

导入数据集

关于cifar10数据集,可以访问它的官网http://www.cs.toronto.edu/~kriz/cifar.html

transform = transforms.Compose( [transforms.RandomHorizontalFlip(), transforms.RandomGrayscale(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]#数据类型的转化,以及将数据进行归一化 ) trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2) #训练集 testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) #测试集 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')#数据的分类 

我查了许多代码,有些代码里会出现下面的一行代码,这个是和前面import ssl相对应的。

ssl._create_default_https_context = ssl._create_unverified_context # 解决访问https时不受ssl信任证书的问题 

定义网络

这里采用的是简单网络处理的具体代码如下:

class Net(nn.Module):#简单网络 def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5)#卷积 self.pool = nn.MaxPool2d(2, 2)#池化 self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84)#全连接层 self.fc3 = nn.Linear(84, 10) def forward(self,x):#构建模型 x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x 

标定义损失函数和优化器

criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 
 for epoch in range(20): timestart = time.time() running_loss = 0.0 for i,data in enumerate(trainloader, 0): inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 500 == 499: print('[%d ,%5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 500)) running_loss = 0.0 print('epoch %d cost %3f sec' % (epoch + 1, time.time()-timestart)) print('Finished Training') 

训练结果

dataiter = iter(testloader) images, labels = dataiter.__next__() imshow(torchvision.utils.make_grid(images)) print('GroundTruth:', ' '.join('%5s' % classes[labels[j]] for j in range(4))) outputs = net(Variable(images)) _, predicted = torch.max(outputs.data,1) print('Predicted:', ' '.join('%5s' % classes[labels[j]] for j in range(4))) correct = 0 total = 0 for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum() print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total)) class_correct = list(0. for i in range(10)) class_total = list(0. for i in range(10)) for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i] class_total[label] += 1 for i in range(10): print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i])) 

采用其他网络进行优化

我这里只尝试了lenet和vgg16两种网络

LeNet

class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(out)) x = F.max_pool2d(x, 2) x = out.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return out 

VGG16

class VGGTest(nn.Module): def __init__(self, pretrained=True, numClasses=10): super(VGGTest, self).__init__() self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.relu1_1 = nn.ReLU(inplace=True) self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) self.relu1_2 = nn.ReLU(inplace=True) self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.relu2_1 = nn.ReLU(inplace=True) self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.relu2_2 = nn.ReLU(inplace=True) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.relu3_1 = nn.ReLU(inplace=True) self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.relu3_2 = nn.ReLU(inplace=True) self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.relu3_3 = nn.ReLU(inplace=True) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) self.relu4_1 = nn.ReLU(inplace=True) self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.relu4_2 = nn.ReLU(inplace=True) self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.relu4_3 = nn.ReLU(inplace=True) self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.relu5_1 = nn.ReLU(inplace=True) self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.relu5_2 = nn.ReLU(inplace=True) self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.relu5_3 = nn.ReLU(inplace=True) self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) # 从原始的 models.vgg16(pretrained=True) 中预设值参数值。 if pretrained: pretrained_model = torchvision.models.vgg16(pretrained=pretrained) # 从预训练模型加载VGG16网络参数 pretrained_params = pretrained_model.state_dict() keys = list(pretrained_params.keys()) new_dict = { 
      } for index, key in enumerate(self.state_dict().keys()): new_dict[key] = pretrained_params[keys[index]] self.load_state_dict(new_dict) self.classifier = nn.Sequential( # 定义自己的分类层 nn.Linear(in_features=512 * 1 * 1, out_features=256), # 自定义网络输入后的大小。 # nn.Linear(in_features=512 * 7 * 7, out_features=256), # 原始vgg16的大小是 512 * 7 * 7 ,由VGG16网络决定的,第二个参数为神经元个数可以微调 nn.ReLU(True), nn.Dropout(), nn.Linear(in_features=256, out_features=256), nn.ReLU(True), nn.Dropout(), nn.Linear(in_features=256, out_features=numClasses), ) 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/216938.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 上午10:42
下一篇 2026年3月18日 上午10:43


相关推荐

  • 上下行harq概念

    上下行harq概念参考前人一些关于 harq 总结 得出自己能够理解的东西

    2025年7月18日
    3
  • SecureCRT之激活教程

    SecureCRT之激活教程因为 SecureCRT 是付费软件 所以需要一定的方法进行激活成功教程 才能便于我们使用 接下来 开始激活成功教程教程 步骤 1 将对应得激活程序拷贝到你安装 SecureCRT 的安装目录下 如果找不到安装目录 你可以到桌面找到该软件快捷方式鼠标右键属性打开文件位置 就找到了 步骤 2 以 管理员 方式运行该注册机 来进行激活程序点击上面的 nbsp Patch nbsp 出现下面的对话框选择之后

    2026年3月26日
    1
  • github代理-github

    github代理-githubgithub 代理网址 https ghproxy com 终端命令行支持终端命令行 gitclone wget curl 等工具下载 支持 raw githubuserco com gist github com gist githubuserco com 文件下载 注意 不支持 SSHKey 方式 gitclone 下载 gitclonegitc ghproxy com https github com st

    2026年3月20日
    2
  • ADB 操作命令详解及用法大全

    ADB 操作命令详解及用法大全ADB操作命令详解及用法大全一、ADB是什么?二、ADB有什么作用?三、ADB命令语法单一设备/模拟器连接多个设备/模拟器连接四、ADB常用命令4.1基本命令4.1.1查看adb的版本信息4.1.2启动adb4.1.3停止adb4.1.4以root权限运行adbd4.1.5指定adbserver的网络端口4.1.5查询已连接的设备/模拟器列表4.2设备连接管理4.2….

    2022年7月27日
    6
  • [MAC] 编译安装和测试《魔兽世界》模拟服务端 TrinityCore

    2019独角兽企业重金招聘Python工程师标准>>>…

    2022年4月17日
    177
  • 微信公众平台开发者社区_php微擎框架

    微信公众平台开发者社区_php微擎框架一、思考开发了几个微信项目,一直在思考:如何将微信相关的处理与业务系统联系在一起?如何做到彼此分离,且易于扩展?能否开发一套独立的微信服务框架,支持各种业务应用?二、现有常用的服务框架支持多种业务应用,我们通过分层的方式来实现。将复杂的系统进行分层,将一些功能或者特有的逻辑进行封装,封装为不同的基础服务或中间件。业务层无需关心底层具体实现,只需进行简单调用、组装

    2022年8月21日
    8

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号