CNN简单实战:pytorch搭建CNN对猫狗图片进行分类

CNN简单实战:pytorch搭建CNN对猫狗图片进行分类上一篇文章介绍了使用pytorch的Dataset和Dataloader处理图片数据,现在就用处理好的数据对搭建的CNN进行训练以及测试。

大家好,又见面了,我是你们的朋友全栈君。

在上一篇文章:CNN训练前的准备:pytorch处理自己的图像数据(Dataset和Dataloader),大致介绍了怎么利用pytorch把猫狗图片处理成CNN需要的数据,今天就用该数据对自己定义的CNN模型进行训练及测试。

  • 首先导入需要的包:
import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
  • 定义自己的CNN网络
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=2,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(3 * 3 * 64, 64)
        self.fc2 = nn.Linear(64, 10)
        self.out = nn.Linear(10, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # print(x.size())
        x = x.view(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.out(x)
        return x
  • 训练(GPU)
def train():
    train_loader, test_loader = load_data()
    print('train...')
    epoch_num = 15
    # GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = cnn().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0008)
    criterion = nn.CrossEntropyLoss().to(device)
    for epoch in range(epoch_num):
        for batch_idx, (data, target) in enumerate(train_loader, 0):
            data, target = Variable(data).to(device), Variable(target.long()).to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))

    torch.save(model.state_dict(), "model/cnn.pkl")

一共训练三轮,训练的步骤如下:

  1. 初始化模型:
model = cnn().to(device)
  1. 选择优化器以及优化算法,这里选择了Adam:
optimizer = optim.Adam(model.parameters(), lr=0.00005)
  1. 选择损失函数,这里选择了交叉熵:
criterion = nn.CrossEntropyLoss().to(device)
  1. 对每一个batch里的数据,先将它们转成能被GPU计算的类型:
 data, target = Variable(data).to(device), Variable(target.long()).to(device)
  1. 梯度清零、前向传播、计算误差、反向传播、更新参数:
optimizer.zero_grad()  # 梯度清0
output = model(data)[0]  # 前向传播
loss = criterion(output, target)  # 计算误差
loss.backward()  # 反向传播
optimizer.step()  # 更新参数
  • 测试(GPU)
def test():
    train_loader, test_loader = load_data()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load('cnn.pkl')  # load model
    total = 0
    current = 0
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)[0]

        predicted = torch.max(outputs.data, 1)[1].data
        total += labels.size(0)
        current += (predicted == labels).sum()

    print('Accuracy: %d %%' % (100 * current / total))

一开始只是进行了3轮训练,结果惨不忍睹:
在这里插入图片描述
随后训练20轮:
在这里插入图片描述
训练30轮:
在这里插入图片描述
如果想继续提高精度,可以再次增加训练轮数。

完整代码及数据我放在了GitHub,各位下载时麻烦给个follow和star!!感谢!!
链接:cnn-dogs-vs-cats

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/131132.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • textview可复制_android长按点击

    textview可复制_android长按点击有这么一个需求,用户在浏览文本信息时希望长按信息就能弹出复制的选项方便保存或者在别的页面使用这些信息.类似的,就像长按WebView或者EditText的内容就自动弹出复制选项.这里面主要是2个特点:1,用户只能浏览文本信息而不能编辑这些文本信息;2,用户对着文本信息长时间点按可以弹出”复制”选项实现复制;网上有好多种方法可实现,也比较零散,此处做个小结,希望有所帮助.1,通过继承EditT…

    2022年9月29日
    0
  • mysql左连接多个表_mysql可以创建多少张表

    mysql左连接多个表_mysql可以创建多少张表A表:姓名,学号,班级编号B表:学号,成绩C表:班级编号,班级名称最后想显示为姓名,学号,成绩,班级名称A、B表用wherea.学号=b.学号查出之后再和C表左连接sql语句如下:selecta.姓名,a.学号,b.成绩,c.班级名称fromA表aleftjoinB表bona.学号=b.学号leftjoinC表cona.班级编号=c.班级编号…

    2022年9月2日
    2
  • CultureInfo 类

    CultureInfo 类CultureInfo类 提供有关特定区域性的信息(如区域性的名称、书写系统和使用的日历)以及如何设置日期和排序字符串的格式。命名空间:System.Globalization程序集:mscorlib(在mscorlib.dll中)varExpCollDivStr=ExpCollDivStr;ExpCollDivStr=ExpCollDiv

    2022年6月19日
    25
  • 动态规划背包问题

    动态规划背包问题一、0-1背包1.      有n个重量和价值分别为wi,vi的物品。从这些物品中挑选出总重量不超过W的物品,求所有挑选方案中价值总和的最大值。(1<=n<=100,1<=wi,vi<=100,1<=W<=10000)样例输入:4231234225 42133457910样例输出…

    2022年7月26日
    5
  • js弹出确认取消对话框_vs点击按钮弹出对话框

    js弹出确认取消对话框_vs点击按钮弹出对话框if(window.confirm(‘你确定要执行删除操作吗?’)){alert(“您点击了确定”);}else{alert(“您点击了取消”);returnfalse;}

    2022年10月25日
    0
  • vue.config.js打包优化(有效)「建议收藏」

    vue.config.js打包优化(有效)「建议收藏」//百度上的资料五花八门让人眼花缭乱,别急,这时候我替你亲身经历了,有需要的可以参考下,先上效果图,以免你们以为我吹牛逼,嘻嘻未优化之前的//感觉太大了抬它优化之后的废话不多说了,上代码是重点这些是必要的下载/*cnpminstallimage-webpack-loader–save-devcnpminstallcompression-webpack-plugin–save-devcnpminstalluglifyjs-webpack-plugin–sa

    2022年6月12日
    89

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号