SENet实战详解:使用SE-ReSNet50实现对植物幼苗的分类

SENet实战详解:使用SE-ReSNet50实现对植物幼苗的分类摘要 1 SENet 概述 Squeeze and ExcitationNe 简称 SENet 是 Momenta 胡杰团队 WMW 提出的新的网络结构 利用 SENet 一举取得最后一届 ImageNet2017 竞赛 ImageClassif 任务的冠军 在 ImageNet 数据集上将 top 5error 降低到 2 251 原先的最好成绩是 2 991 作者在文中将 SENetblock 插入到现有的多种分类网络中 都取得了不错的效果 作者的动机是希望显式地建模

摘要

1、SENet概述

​ Squeeze-and-Excitation Networks(简称 SENet)是 Momenta 胡杰团队(WMW)提出的新的网络结构,利用SENet,一举取得最后一届 ImageNet 2017 竞赛 Image Classification 任务的冠军,在ImageNet数据集上将top-5 error降低到2.251%,原先的最好成绩是2.991%。

作者在文中将SENet block插入到现有的多种分类网络中,都取得了不错的效果。作者的动机是希望显式地建模特征通道之间的相互依赖关系。另外,作者并未引入新的空间维度来进行特征通道间的融合,而是采用了一种全新的「特征重标定」策略。具体来说,就是通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

通俗的来说SENet的核心思想在于通过网络根据loss去学习特征权重,使得有效的feature map权重大,无效或效果小的feature map权重小的方式训练模型达到更好的结果。SE block嵌在原有的一些分类网络中不可避免地增加了一些参数和计算量,但是在效果面前还是可以接受的 。Sequeeze-and-Excitation(SE) block并不是一个完整的网络结构,而是一个子结构,可以嵌到其他分类或检测模型中。

2、SENet 结构组成详解

上述结构中,Squeeze 和 Excitation 是两个非常关键的操作,下面进行详细说明。

img

上图是SE 模块的示意图。给定一个输入 x,其特征通道数为 C ′ {C}’ C,通过一系列卷积等一般变换后得到一个特征通道数为C 的特征。通过下面的三个操作还重标前面得到的特征:

1、Squeeze 操作,顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。

2、 Excitation 操作,它是一个类似于循环神经网络中门的机制。通过参数 w 来为每个特征通道生成权重,其中参数 w 被学习用来显式地建模特征通道间的相关性。

3、 Reweight 操作,将 Excitation 的输出的权重看做是进过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定。

3、详细的计算过程

首先 F t r F_{tr} Ftr这一步是转换操作(严格讲并不属于SENet,而是属于原网络,可以看后面SENet和Inception及ResNet网络的结合),在文中就是一个标准的卷积操作而已,输入输出的定义如下表示:

img

那么这个 F t r F_{tr} Ftr的公式就是下面的公式1(卷积操作, V c V_{c} Vc表示第c个卷积核, X s X^{s} Xs表示第s个输入)。

img
F t r F_{tr} Ftr得到的U就是Figure1中的左边第二个三维矩阵,也叫tensor,或者叫C个大小为HW的feature map。而uc表示U中第c个二维矩阵,下标c表示channel。
接下来就是Squeeze操作,公式非常简单,就是一个global average pooling:
img
因此公式2就将的 H × W × C H \times W \times C H×W×C输入转换成 1 × 1 × C 1 \times 1 \times C 1×1×C的输出,对应Figure1中的Fsq操作。为什么会有这一步呢?这一步的结果相当于表明该层C个feature map的数值分布情况,或者叫全局信息。
再接下来就是Excitation操作,如公式3。直接看最后一个等号,前面squeeze得到的结果是z,这里先用W1乘以z,就是一个全连接层操作,W1的维度是 C / r × C C/ r \times C C/r×C,这个r是一个缩放参数,在文中取的是16,这个参数的目的是为了减少channel个数从而降低计算量。又因为z的维度是 1 × 1 × C 1 \times 1\times C 1×1×C所以W1z的结果就是 1 × 1 × C / r 1 \times 1 \times C / r 1×1×C/r;然后再经过一个ReLU层,输出的维度不变;然后再和W2相乘,和W2相乘也是一个全连接层的过程,W2的维度是 C × C / r C \times C/r C×C/r,因此输出的维度就是 1 × 1 × C 1 \times 1 \times C 1×1×C;最后再经过sigmoid函数,得到s:
img
也就是说最后得到的这个s的维度是 1 × 1 × C 1 \times 1 \times C 1×1×C,C表示channel数目。这个s其实是本文的核心,它是用来刻画tensor U中C个feature map的权重。而且这个权重是通过前面这些全连接层和非线性层学习得到的,因此可以end-to-end训练。这两个全连接层的作用就是融合各通道的feature map信息,因为前面的squeeze都是在某个channel的feature map里面操作。
在得到s之后,就可以对原来的tensor U操作了,就是下面的公式4。也很简单,就是channel-wise multiplication,什么意思呢? u c u_{c} uc是一个二维矩阵, s c s_{c} sc是一个数,也就是权重,因此相当于把矩 u c u_{c} uc阵中的每个值都乘以 s c s_{c} sc。对应Figure1中的Fscale。







img

SENet 在具体网络中应用(代码实现SE_ResNet)

介绍完具体的公式实现,下面介绍下SE block怎么运用到具体的网络之中。

img

上图是将 SE 模块嵌入到 Inception 结构的一个示例。方框旁边的维度信息代表该层的输出。

这里我们使用 global average pooling 作为 Squeeze 操作。紧接着两个 Fully Connected 层组成一个 Bottleneck 结构去建模通道间的相关性,并输出和输入特征同样数目的权重。我们首先将特征维度降低到输入的 1/16,然后经过 ReLu 激活后再通过一个 Fully Connected 层升回到原来的维度。这样做比直接用一个 Fully Connected 层的好处在于:

1)具有更多的非线性,可以更好地拟合通道间复杂的相关性;

2)极大地减少了参数量和计算量。然后通过一个 Sigmoid 的门获得 0~1 之间归一化的权重,最后通过一个 Scale 的操作来将归一化后的权重加权到每个通道的特征上。

除此之外,SE 模块还可以嵌入到含有 skip-connections 的模块中。上右图是将 SE 嵌入到 ResNet 模块中的一个例子,操作过程基本和 SE-Inception 一样,只不过是在 Addition 前对分支上 Residual 的特征进行了特征重标定。如果对 Addition 后主支上的特征进行重标定,由于在主干上存在 0~1 的 scale 操作,在网络较深 BP 优化时就会在靠近输入层容易出现梯度消散的情况,导致模型难以优化。

目前大多数的主流网络都是基于这两种类似的单元通过 repeat 方式叠加来构造的。由此可见,SE 模块可以嵌入到现在几乎所有的网络结构中。通过在原始网络结构的 building block 单元中嵌入 SE 模块,我们可以获得不同种类的 SENet。如 SE-BN-Inception、SE-ResNet、SE-ReNeXt、SE-Inception-ResNet-v2 等等。

本例通过实现SE-ResNet,来显示如何将SE模块嵌入到ResNet网络中。SE-ResNet模型如下图:

img

实战详解

1、数据集

在工程的根目录新建data文件夹,获取数据集后,将trian和test解压放到data文件夹下面,如下图:

image-20211021182712396

2、安装库,并导入需要的库

本项目用到pretrainedmodels,这里有seresenet的预训练模型。安装方法:

pip install pretrainedmodels 

安装完成后,导入到项目中。

import torch.optim as optim import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from dataset.dataset import SeedlingData from torch.autograd import Variable import pretrainedmodels 

3、设置全局参数

设置使用GPU,设置学习率、BatchSize、epoch等参数

# 设置全局参数 modellr = 1e-4 BATCH_SIZE = 16 EPOCHS = 50 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

4、数据预处理

数据处理比较简单,没有做复杂的尝试,有兴趣的可以加入一些处理。

# 数据预处理 transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) 

5、数据读取

然后我们在dataset文件夹下面新建 init.py和dataset.py,在mydatasets.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

代码如下:

# coding:utf8 import os from PIL import Image from torch.utils import data from torchvision import transforms as T from sklearn.model_selection import train_test_split Labels = { 
   'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11} class SeedlingData (data.Dataset): def __init__(self, root, transforms=None, train=True, test=False): """ 主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据 """ self.test = test self.transforms = transforms if self.test: imgs = [os.path.join(root, img) for img in os.listdir(root)] self.imgs = imgs else: imgs_labels = [os.path.join(root, img) for img in os.listdir(root)] imgs = [] for imglable in imgs_labels: for imgname in os.listdir(imglable): imgpath = os.path.join(imglable, imgname) imgs.append(imgpath) trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42) if train: self.imgs = trainval_files else: self.imgs = val_files def __getitem__(self, index): """ 一次返回一张图片的数据 """ img_path = self.imgs[index] img_path=img_path.replace("\\",'/') if self.test: label = -1 else: labelname = img_path.split('/')[-2] label = Labels[labelname] data = Image.open(img_path).convert('RGB') data = self.transforms(data) return data, label def __len__(self): return len(self.imgs) 

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from mydatasets import SeedlingData)

# 读取数据 dataset_train = SeedlingData('data/train', transforms=transform, train=True) dataset_test = SeedlingData("data/train", transforms=transform_test, train=False) # 导入数据 train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True) test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False) 

6、设置模型

  • 设置loss函数为nn.CrossEntropyLoss()。
  • 设置模型为se_resnet50,修改最后一层全连接输出改为12。
  • 优化器设置为adamw。
# 实例化模型并且移动到GPU criterion = nn.CrossEntropyLoss() model_ft = pretrainedmodels.__dict__['se_resnet50'](num_classes=1000, pretrained='imagenet') model_ft.fc = classifier = nn.Sequential( nn.Linear(2048, 512), nn.LeakyReLU(True), nn.Dropout(0.5), nn.Linear(512, 12), ) model_ft.to(DEVICE) # 选择简单暴力的Adam优化器,学习率调低 optimizer = optim.AdamW(model_ft.parameters(), lr=modellr) 

7、定义训练和验证函数

def adjust_learning_rate(optimizer, epoch): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" modellrnew = modellr * (0.1  (epoch // 50)) print("lr:", modellrnew) for param_group in optimizer.param_groups: param_group['lr'] = modellrnew # 定义训练过程 def train(model, device, train_loader, optimizer, epoch): model.train() sum_loss = 0 total_num = len(train_loader.dataset) print(total_num, len(train_loader)) for batch_idx, (data, target) in enumerate(train_loader): data, target = Variable(data).to(device), Variable(target).to(device) output = model(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() print_loss = loss.data.item() sum_loss += print_loss if (batch_idx + 1) % 10 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, (batch_idx + 1) * len(data), len(train_loader.dataset), 100. * (batch_idx + 1) / len(train_loader), loss.item())) ave_loss = sum_loss / len(train_loader) print('epoch:{},loss:{}'.format(epoch, ave_loss)) # 验证过程 def val(model, device, test_loader): model.eval() test_loss = 0 correct = 0 total_num = len(test_loader.dataset) print(total_num, len(test_loader)) with torch.no_grad(): for data, target in test_loader: data, target = Variable(data).to(device), Variable(target).to(device) output = model(data) loss = criterion(output, target) _, pred = torch.max(output.data, 1) correct += torch.sum(pred == target) print_loss = loss.data.item() test_loss += print_loss correct = correct.data.item() acc = correct / total_num avgloss = test_loss / len(test_loader) print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( avgloss, correct, len(test_loader.dataset), 100 * acc)) # 训练 for epoch in range(1, EPOCHS + 1): adjust_learning_rate(optimizer, epoch) train(model_ft, DEVICE, train_loader, optimizer, epoch) val(model_ft, DEVICE, test_loader) torch.save(model_ft, 'model.pth') 

8、测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:

image-20211021184219787

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common wheat', 'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet') 

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) 

第三步 加载model,并将模型放在DEVICE里。

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = torch.load("model.pth") model.eval() model.to(DEVICE) 

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

path = 'data/test/' testList = os.listdir(path) for file in testList: img = Image.open(path + file) img = transform_test(img) img.unsqueeze_(0) img = Variable(img).to(DEVICE) out = model(img) # Predict _, pred = torch.max(out.data, 1) print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()])) 

第二种,使用自定义的Dataset读取图片。前三步同上,差别主要在第四步。读取数据的时候,使用Dataset的SeedlingData读取。

dataset_test =SeedlingData('data/test/', transform_test,test=True) print(len(dataset_test)) # 对应文件夹的label for index in range(len(dataset_test)): item = dataset_test[index] img, label = item img.unsqueeze_(0) data = Variable(img).to(DEVICE) output = model(data) _, pred = torch.max(output.data, 1) print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()])) index += 1 

关注公众号,回复“senet实战”,获取代码、数据集和模型。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/233404.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 图解 Vue 响应式原理

    图解 Vue 响应式原理最近部门分享,有同学提到了Vue响应式原理,大家在讨论时,发现一些同学对这一知识理解还不够深入,不能形成一个闭环,为了帮助大家理解这个问题,我重新过了一下Vue源码,并整理了多张流程图,便于大家理解。Vue初始化模板渲染组件渲染本文Vue源码版本:2.6.11,为了便于理解,均有所删减。本文将从以下两个方面进行探索:从Vue初始化,到首次渲染生成DOM的流程。从Vue数据修改,到页面更新DOM的流程。Vue初始化先从最简单的一段Vue

    2022年4月30日
    38
  • 国内教育邮箱有什么用_学校教育网邮箱

    国内教育邮箱有什么用_学校教育网邮箱教育优惠,是一项针对于在校大学生和教职员工推出的特殊优惠活动。一些公司会将旗下产品或服务以一定的折扣,甚至免费提供给高校师生。想想自己上大学的时候啥都不知道,毕业后才发现浪费了这么多优秀的资源.如果你还是一名在校大学生,那么就不要错过以下的这些教育优惠了,不然真的是失去后才会追悔莫及.申请教育优惠一般有一个前提,那就是要有一个学校提供的「校园邮箱」.在国内,也就是@xxx.edu.cn结尾……

    2022年9月23日
    1
  • 我是如何从零开始 Web 前端自学之路的?

    我是如何从零开始 Web 前端自学之路的?作者|六小登登责编|屠敏从2013年专科毕业开始,一路跌跌撞撞走了很多弯路,做过餐厅服务员,进过工厂干过流水线,做过客服,干过电话销售可以说经历相当的“丰富”…

    2022年5月20日
    34
  • matlab求两向量夹角_MATLAB基础练习(一)

    matlab求两向量夹角_MATLAB基础练习(一)1、按要求写出实现该功能的代码(1)使用方括号“[]”操作符产生一个列向量x,内容为1,2,4,7(2)使用方括号“[]”操作符产生一个行向量x,内容为1,2,4,7(3)使用冒号“:”操作符产生一个行向量x,内容为9,7,5,3,1(4)使用方括号“[]”操作符产生一个二维数组A,第1行为9,4,5,1;第2行为1,0,4,7(5)使用zeros函数产生一个3*2的二维数组A,使用one…

    2022年8月30日
    2
  • 走过2011—年终总结

    走过2011—年终总结

    2021年9月7日
    55
  • Android 自定义 ViewPager 打造千变万化的图片切换效果

    Android 自定义 ViewPager 打造千变万化的图片切换效果转载请标明出处:http://blog.csdn.net/lmj623565791/article/details/38026503 记得第一次见到ViewPager这个控件,瞬间爱不释手,做东西的主界面通通ViewPager,以及图片切换也抛弃了ImageSwitch之类的,开始让ViewPager来做。时间长了,ViewPager的切换效果觉得枯燥,形成了审美疲劳~~我们需要改变,今天教大家如

    2022年7月22日
    9

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号