CNN训练前的准备:pytorch处理自己的图像数据(Dataset和Dataloader)

CNN训练前的准备:pytorch处理自己的图像数据(Dataset和Dataloader)pytorch的torchvision给我们提供了很多已经封装好的数据集,但是我们经常得使用自己找到的数据集,因此,想要得到一个好的训练结果,合理的数据处理是必不可少的。

大家好,又见面了,我是你们的朋友全栈君。

链接:cnn-dogs-vs-cats

  pytorch给我们提供了很多已经封装好的数据集,但是我们经常得使用自己找到的数据集,因此,想要得到一个好的训练结果,合理的数据处理是必不可少的。我们以1400张猫狗图片来进行分析。

  1. 分析数据:
    在这里插入图片描述
    在这里插入图片描述

训练集包含500张狗的图片以及500张猫的图片,测试接包含200张狗的图片以及200张猫的图片。

  1. 数据预处理:得到一个包含所有图片文件名(包含路径)和标签(狗1猫0)的列表:
def init_process(path, lens):
    data = []
    name = find_label(path)
    for i in range(lens[0], lens[1]):
        data.append([path % i, name])
        
    return data

  现有数据的命名都是有序号的,训练集中数据编号为0-499,测试集中编号为1000-1200,因此我们可以根据这个规律来读取文件名,比如参数传入:

path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])

data1就是一个包含五百个文件名以及标签的列表。find_label来判断标签是dog还是cat:

def find_label(str):
    first, last = 0, 0
    for i in range(len(str) - 1, -1, -1):
        if str[i] == '%' and str[i - 1] == '.':
            last = i - 1
        if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
            first = i
            break

    name = str[first:last]
    if name == 'dog':
        return 1
    else:
        return 0

dog返回1,cat返回0。
  有了上面两个函数之后,我们经过四次操作,就可以得到四个列表:

path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])
path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
data2 = init_process(path2, [0, 500])
path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
data3 = init_process(path3, [1000, 1200])
path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
data4 = init_process(path4, [1000, 1200])

随便输出一个列表的前五个:

[['cnn_data/data/testing_data/dogs/dog.1000.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1001.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1002.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1003.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1004.jpg', 1]]
  1. 利用PIL包的Image处理图片:
def Myloader(path):
    return Image.open(path).convert('RGB')
  1. 重写pytorch的Dataset类:
class MyDataset(Dataset):
    def __init__(self, data, transform, loder):
        self.data = data
        self.transform = transform
        self.loader = loder
    def __getitem__(self, item):
        img, label = self.data[item]
        img = self.loader(img)
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data)

里面有2个比较重要的函数:

  • __getitem__是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可。在这个函数里面,我们对第二步处理得到的列表进行索引,接着利用第三步定义的Myloader来对每一个路径进行处理,最后利用pytorch的transforms对RGB数据进行处理,将其变成Tensor数据。
    transform为:
transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 归一化
    ])

对上面四个操作做一些解释:
1)、transforms.CenterCrop(224),从图像中心开始裁剪图像,224为裁剪大小
2)、transforms.Resize((224, 224)),重新定义图像大小
3)、 transforms.ToTensor(),很重要的一步,将图像数据转为Tensor
4)、transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),归一化

  • __len__中提供迭代器索引的范围。

因此我们只需要:

train_data = data1 + data2 + data3[0:150] + data4[0:150]
train = MyDataset(train_data, transform=transform, loder=Myloader)
test_data = data3[150:200] + data4[150:200]
test= MyDataset(test_data, transform=transform, loder=Myloader)

就可以得到处理好的Dataset,其中训练集我给了1300张图片,测试集只给了100张。

  1. 通过pytorch的DataLoader对第四步得到的Dataset进行shuffle以及mini-batch操作,分成一个个小的数据集:
train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)

最后我们只要给定义好的神经网络模型喂数据就OK了!!!
完整代码:

# -*- coding: utf-8 -*-
""" @Time : 2020/8/18 9:11 @Author :KI @File :CNN.py @Motto:Hungry And Humble """
import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def Myloader(path):
    return Image.open(path).convert('RGB')

#得到一个包含路径与标签的列表
def init_process(path, lens):
    data = []
    name = find_label(path)
    for i in range(lens[0], lens[1]):
        data.append([path % i, name])

    return data

class MyDataset(Dataset):
    def __init__(self, data, transform, loder):
        self.data = data
        self.transform = transform
        self.loader = loder
    def __getitem__(self, item):
        img, label = self.data[item]
        img = self.loader(img)
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data)


def find_label(str):
    first, last = 0, 0
    for i in range(len(str) - 1, -1, -1):
        if str[i] == '%' and str[i - 1] == '.':
            last = i - 1
        if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
            first = i
            break

    name = str[first:last]
    if name == 'dog':
        return 1
    else:
        return 0

def load_data():
    transform = transforms.Compose([
        #transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomVerticalFlip(p=0.5),
        transforms.CenterCrop(224),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 归一化
    ])
    path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
    data1 = init_process(path1, [0, 500])
    path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
    data2 = init_process(path2, [0, 500])
    path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
    data3 = init_process(path3, [1000, 1200])
    path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
    data4 = init_process(path4, [1000, 1200])

    train_data = data1 + data2 + data3[0:150] + data4[0:150]

    train = MyDataset(train_data, transform=transform, loder=Myloader)

    test_data = data3[150:200] + data4[150:200]
    test= MyDataset(test_data, transform=transform, loder=Myloader)

    train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
    test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)

    return train_data, test_data

train_data以及test_data就是我们最终需要得到的数据。

对猫狗数据分类的具体实现请见:CNN简单实战:pytorch搭建CNN对猫狗图片进行分类

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/130066.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 最新 Cocos2d-x 3.2 开发环境搭建(windows环境下)

    最新 Cocos2d-x 3.2 开发环境搭建(windows环境下)

    2021年12月5日
    47
  • JavaScript——利用正则表达式实现二代身份证号码的验证

    JavaScript——利用正则表达式实现二代身份证号码的验证HTML<divclass=”login-header”><aid=”link”>点击,弹出登录框</a></div><divclass=”box”id=”box”><divclass=”hd”id=”drop”>注册信息(可以拖拽)<spanid=”box_close”>[关闭]</span></div><divclas

    2022年6月27日
    32
  • JAVA——数组截取——调用库中方法

    JAVA——数组截取——调用库中方法1,使用Java类库中的方法System.arraycopy2,使用Java类库中的方法java.util.Arrays.copyOf3,重写myCopy(一)使用.arraycopy方法使用方法:System.arraycopy(源数组名称,源数组开始点,目标数组名称,目标数组开始点,拷贝长度);说明:将arr1数组中的一部分替换成arr2数组中的一部分可以从任意位置开始截取…

    2022年6月9日
    161
  • 安卓log日志查看工具_手机怎么查看错误日志

    安卓log日志查看工具_手机怎么查看错误日志一个完整的程序日志记录功能是必不可少的,通过日志我们可以了解程序运行详情、错误信息等,以便更好的发现及解决问题。日志可以记录到数据库、日志服务器、文件等地方,本文主要介绍文件日志。 文

    2022年8月1日
    5
  • 根据美光内存颗粒上的编码查询对应型号

    根据美光内存颗粒上的编码查询对应型号根据美光内存颗粒上的编码查询对应型号今天遇到需要查看美光内存颗粒容量的问题。美光FBGA封装的DDR颗粒上只有两行,每行5位的编码。根据美光官网上的说明,由于FBGA封装上空间的限制,不能印完整的型号信息,只能用编码表示,其中第二行的5位编码可以用于查询对应的型号信息。官方提供了FBGA&ComponentMarkingDecoder工具来查询FBGAcode对应的型号,进而就可以查找到了

    2022年6月22日
    31
  • 权限设计-系统登录用户权限设计[通俗易懂]

    权限设计-系统登录用户权限设计[通俗易懂]需求分析—场景假设需要为公司设计一个人员管理系统,并为各级领导及全体员工分配系统登录账号。有如下几个要求:1. 权限等级不同:公司领导登录后可查看所有员工信息,部门领导登录后只可查看本部门员工的信息,员工登录后只可查看自己的信息;2.访问权限不同:如公司领导登录后,可查看员工薪水分布界面,而员工则不能看到;3.操作权限不同:如系统管理员可以在信息发布界面进行增删改查发布信息

    2022年7月13日
    14

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号