Pytroch入坑 3. 自己的人脸数据+迁移学习(resnet18)

Pytroch入坑 3. 自己的人脸数据+迁移学习(resnet18)本文转载自:http://www.zhongruitech.com/856941441.html0.前言之前是使用了mnist数据,且网络结构比较简单,针对自己的数据,如何使用更复杂、经典的网络呢?1.数据集目标是人脸识别,可以看做一个多分类问题,本次实验的数据集为ferest,共200个人,1400张38080图片,比较小。分为train和val两个目录,每个目录下都有200个…

大家好,又见面了,我是你们的朋友全栈君。

本文转载自:http://www.zhongruitech.com/856941441.html

0.前言

之前是使用了mnist数据,且网络结构比较简单,针对自己的数据,如何使用更复杂、经典的网络呢?

1.数据集

目标是人脸识别,可以看做一个多分类问题,本次实验的数据集为ferest,共200个人,1400张38080图片,比较小。
在这里插入图片描述
分为 train 和 val两个目录,每个目录下都有200个子目录。

资源可下载

https://download.csdn.net/download/sinat_37787331/10383836

注意:训练和测试的目录名字和数量必须保持一致,子目录内可以没有图片。

附批量删除、批量改格式的代码

#!/usr/bin/python
# -*- coding: utf-8 -*-
import os
def del_files(path):
  for root , dirs, files in os.walk(path):
    for name in files:
      if name.endswith(".png"):
        os.remove(os.path.join(root, name))

        print ("Delete File: " + os.path.join(root, name))
# test
if __name__ == "__main__":
  path = '/home/syj/Documents/datas/2'
  del_files(path)


gai hou zhui

#!/usr/bin/python
# -*- coding: utf-8 -*-
import os

def model_extentsion(path,before_ext,ext):
    for name in os.listdir(path):
        full_path=os.path.join(path,name)
        if os.path.isfile(full_path):
            split_path=os.path.splitext(full_path)
            pwd_name=split_path[0]
            pwd_ext=split_path[1]
            before_ext1="."+before_ext
            if pwd_ext == before_ext1:
                ext1="."+ext
                pwd_name+=ext1
                re_name=os.path.join(path,pwd_name)
                os.renames(full_path, re_name) 

        else:
            model_extentsion(full_path,before_ext,ext) 

model_extentsion("/home/syj/Documents/datas/Feret/train",'tif', "png")

2.数据加载

这次加载的是自己的数据,大体分为两种

第一种:图片文件夹+txt文档

可借鉴 http://www.bubuko.com/infodetail-2304938.html

**第二种:训练集和测试集分开,且每一类文件都放在同一子目录下。**本文采用这种方法

# 数据人脸

train_data = torchvision.datasets.ImageFolder('/home/syj/Documents/datas/Feret/train',
                                            transform=transforms.Compose([
                                                transforms.Resize(28),

                                                transforms.ToTensor(),
                                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

                                            ]))
                                            

# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.ImageFolder('/home/syj/Documents/datas/Feret/val',
                                            transform=transforms.Compose([
                                                transforms.Resize(28),

                                                transforms.ToTensor(),
                                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                            ]))
test_loader = Data.DataLoader(dataset=test_data, batch_size=20, shuffle=True)


    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),  ##224*224为resnet18输入图片尺寸
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),  #归一化
    }

主要是通过 torchvision.datasets.ImageFolder 这个函数实现的,很方便

具体的归一化等操作介绍可以参考 https://blog.csdn.net/Hungryof/article/details/76649006

3.迁移学习,加载resnet18模型,并进行fine-tuning

官方提供了许多经典的模型,如alnex,vgg,resnet,并且有训练过的参数,可以用来迁移学习

# model_ft = models.resnet18(pretrained=True)
# num_ftrs = model_ft.fc.in_features
# model_ft.fc = nn.Linear(num_ftrs, 200)

三行代码就搭建好了网络,会自动下载resnet18,只是把最后一层fc层由1000(Imaginenet)改为200就行了

4.模型保存和加载

有两种方法,一种只保存参数,一种全保存,后者简单但存储量大,我用的是后者

model_ft = torch.load('/home/syj/Documents/model/resnet18_0.003.pkl')

#torch.save(model_ft, '/home/syj/Documents/model/resnet18_0.003.pkl')

5.结果

我跑了9个epoch,200类的acc在72%左右,接近理论,花了4分钟(gt940m)

可以参考 http://www.cnblogs.com/denny402/p/7520063.html

6.完整代码

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import matplotlib as mpl
import matplotlib.pyplot as plt

def train_model(model, criterion, optimizer, scheduler, num_epochs=1):
    since = time.time()

    best_model_wts = model.state_dict()
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for data in dataloders[phase]:
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    train_loss.append(loss.data[0] / 15)
                    train_acc.append(torch.sum(preds == labels.data) / 15)
                else:
                    test_loss.append(loss.data[0] / 15)
                    test_acc.append(torch.sum(preds == labels.data) / 15)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

if __name__ == '__main__':

    # data_transform, pay attention that the input of Normalize() is Tensor and the input of RandomResizedCrop() or RandomHorizontalFlip() is PIL Image
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # your image data file
    data_dir = '/home/syj/Documents/datas/Feret'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x]) for x in ['train', 'val']}
    # wrap your data and label into Tensor
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=10,
                                                 shuffle=True,
                                                 num_workers=10) for x in ['train', 'val']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    # use gpu or not
    use_gpu = torch.cuda.is_available()

    # get model and replace the original fc layer with your fc layer
    # model_ft = models.resnet18(pretrained=True)
    # num_ftrs = model_ft.fc.in_features
    # model_ft.fc = nn.Linear(num_ftrs, 200)
    model_ft = torch.load('/home/syj/Documents/model/resnet18_0.003.pkl')

    ##paint
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    if use_gpu:
        model_ft = model_ft.cuda()

    # define loss function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model_ft = train_model(model=model_ft,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=2)

    #torch.save(model_ft, '/home/syj/Documents/model/resnet18_0.003.pkl')

    ##paint
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.plot(train_loss, lw = 1.5, label = 'train_loss')

    plt.subplot(2, 2, 2)
    plt.plot(train_acc, lw = 1.5, label = 'train_acc')

    plt.subplot(2, 2, 3)
    plt.plot(test_loss, lw = 1.5,label = 'loss')

    plt.subplot(2, 2, 4)
    plt.plot(test_acc, lw = 1.5, label = 'acc')
    plt.savefig("resnet18_0.01-10.jpg")
    plt.show()
    print(dataset_sizes)

----------
train Loss: 0.1916 Acc: 0.8083
val Loss: 0.0262 Acc: 0.9778
Epoch 24/24
----------
train Loss: 0.2031 Acc: 0.8250
val Loss: 0.0269 Acc: 1.0000
Training complete in 4m 19s
Best val Acc: 1.000000

'''

'''  lr=0.003
Epoch 9/9
----------
train Loss: 0.1358 Acc: 0.6710
val Loss: 0.1135 Acc: 0.6575
Training complete in 9m 43s
Best val Acc: 0.657500
'''
''' lr=0.01 15
Epoch 9/9
----------
train Loss: 0.0415 Acc: 0.8530
val Loss: 0.0802 Acc: 0.7225
Training complete in 10m 6s
Best val Acc: 0.722500
'''

''' 0.01 10
Epoch 38/39
----------
train Loss: 0.0509 Acc: 0.8640
val Loss: 0.1262 Acc: 0.7325
Epoch 39/39
----------
train Loss: 0.0508 Acc: 0.8520
val Loss: 0.1396 Acc: 0.7200
Training complete in 4m 13s
Best val Acc: 0.737500

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/141489.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 论文阅读报告_小论文

    论文阅读报告_小论文FactorizingYAGOScalableMachineLearningforLinkedData关联数据的可扩展机器学习分解发表于WWW2012–Session:CreatingandUsingLinksbetweenDataObjects摘要:语义Web的链接开放数据(LOD)云中已经发布了大量的结构化信息,而且它们的规模仍在快速增长。然而,由于LOD的大小、部分数据不一致和固有的噪声,很难通过推理和查询访问这些信息。本文提出了一种高效的LOD数据关系学习方

    2025年8月20日
    2
  • ringbuffer 无锁队列_wear ring

    ringbuffer 无锁队列_wear ring最近常收到SOD框架的朋友报告的SOD的SQL日志功能报错:文件句柄丢失。经过分析得知,这些朋友使用SOD框架开发了访问量比较大的系统,由于忘记关闭SQL日志功能所以出现了很高频率的日志写入操作,从而偶然引起错误。后来我建议只记录出错的或者执行时间较长的SQL信息,暂时解决了此问题。但是作为一个热心造轮子的人,一定要看看能不能造一个更好的轮子出来。前面说的错误原因已经很直白了,就是频繁的日志写入导…

    2025年10月18日
    3
  • win10共享打印错误0x0000006_win7打印机共享出现0x000006d9错误的解决方法

    win10共享打印错误0x0000006_win7打印机共享出现0x000006d9错误的解决方法这两天在WIN7上安装了一个HP1320打印机,装驱动,后来共享,在共享的时候出错了,发现问题,竟然无法共享打印机。发现出现“0x000006d9错误”,在开始的时候测试打印没问题,就是不能共享,只能说系统本身的问题了。后来经过查询资料,发现很有人说把windowsfirewall服务打开,就可以共享。我装系统的时候装了卡巴,它自动把防火墙关掉,我打开services.msc后发现这个wind…

    2022年5月14日
    146
  • python解析xps文件_xps文件的基本操作

    python解析xps文件_xps文件的基本操作最近一直研究XPS文件,目前已经解决了二进制流转XPS文件、XPS文件转二进流、XPS文件的解析、XPS文件转图片、XPS文件打印等。但是一直没有找到如何向xps文件中插入图片的方法,好烦恼啊!!!!如果那位大神有向xps文件中插入图片的方法请及时联系我谢谢,QQ470163177。本人研究的成果如下,需要的码友可以学习下。注意:xps命名空间在ReachFramework.dll中using…

    2022年6月3日
    35
  • SQL EXITS用法

    SQL EXITS用法比如在Northwind数据库中有一个查询为SELECTc.CustomerId,CompanyNameFROMCustomerscWHEREEXISTS(SELECTOrderIDFROMOrdersoWHEREo.CustomerID=c.CustomerID) 这里面的EXISTS是如何运作呢?子查询返回的是OrderId字段,可是外面的查询要找的是Cu

    2025年6月24日
    3
  • idea2019.3.4激活码【2021最新】

    (idea2019.3.4激活码)最近有小伙伴私信我,问我这边有没有免费的intellijIdea的激活码,然后我将全栈君台教程分享给他了。激活成功之后他一直表示感谢,哈哈~IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.html…

    2022年3月30日
    52

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号