pytorch实现Senet 代码详解

pytorch实现Senet代码详解。Senet的优点：senet的优点在于增加少量的参数便可以一定程度地提高模型的准确率，是第一个在成型的模型基础之上建立的策略，创新点非常好，很适合自己创作新模型刷高准确率。Senet的结构：本文的代码讲解以resnet50为例，senet应用于已经构造完成的resnet模型，只不过加上了一层se结构的卷积；se结构是在特征图最后进行的，以out_channels作为输入。

Senet的优点

senet的优点在于增加少量的参数便可以一定程度的提高模型的准确率,是第一个在成型的模型基础之上建立的策略,创新点非常的好,很适合自己创作新模型刷高准确率的一种方法。

Senet的结构

在这里插入图片描述
本文的代码讲解以resnet50为例，上图便是senet的结构，应用于已经构造完成的resnet模型，只不过再加上了一层se结构的卷积。se结构是在特征图最后进行的，以out_channels作为输入，然后经历整个se结构的卷积处理。这种先压缩再膨胀的过程可以看做是不同层特征图数据的交融，与原本进行1X1的卷积类似，可以加强非线性和跨通道的信息交互。

Se结构代码

 # SE (squeeze-and-excitation) gate built from the block's output width `filter3`:
 # global average pool squeezes each feature map to 1x1, a 1x1 conv reduces the
 # channel count 16x (the SE reduction ratio), ReLU, a second 1x1 conv restores
 # `filter3` channels, and Sigmoid maps each channel to a weight in (0, 1).
 # NOTE(review): assumes filter3 >= 16 so that filter3//16 > 0 — holds for the
 # ResNet-50 widths used in this article (256/512/1024/2048).
 self.se = nn.Sequential( nn.AdaptiveAvgPool2d((1,1)), nn.Conv2d(filter3,filter3//16,kernel_size=1), nn.ReLU(), nn.Conv2d(filter3//16,filter3,kernel_size=1), nn.Sigmoid() ) 

整体结构

在这里插入图片描述
SE-resnet-50基本跟resnet-50没有变化，唯一变化就是加上了se这个结构，可以参考我之前写的resnet讲解对比学习，代码也基本相同。卷积结构需要重复利用，所以写成一个模块。

class Block(nn.Module):
    """SE-ResNet bottleneck block.

    A 1x1 -> 3x3 -> 1x1 convolution stack whose output is re-weighted
    per channel by a squeeze-and-excitation (SE) gate, then added to a
    residual shortcut and passed through ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        filters: 3-tuple of output widths for the three conv stages.
        stride: stride applied by the first 1x1 conv (and the shortcut).
        is_1x1conv: if True, project the shortcut with a strided 1x1 conv
            so its shape matches the main path.
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super(Block, self).__init__()
        width1, width2, width3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: 1x1 conv reduces channels (carries the stride, if any).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, width1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(width1),
            nn.ReLU(),
        )
        # Stage 2: 3x3 conv at the reduced width.
        self.conv2 = nn.Sequential(
            nn.Conv2d(width1, width2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width2),
            nn.ReLU(),
        )
        # Stage 3: 1x1 conv expands back; no ReLU until after the residual add.
        self.conv3 = nn.Sequential(
            nn.Conv2d(width2, width3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(width3),
        )
        if is_1x1conv:
            # Projection shortcut: match the main path's channels and stride.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width3),
            )
        # SE gate: squeeze to 1x1, reduce channels 16x, expand back, and emit
        # a sigmoid weight per channel.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(width3, width3 // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(width3 // 16, width3, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        identity = x
        out = self.conv3(self.conv2(self.conv1(x)))
        gate = self.se(out)          # per-channel attention weights in (0, 1)
        out = out * gate             # recalibrate the feature maps
        if self.is_1x1conv:
            identity = self.shortcut(identity)
        return self.relu(out + identity)

全部代码

import torch
import torch.nn as nn


class Block(nn.Module):
    """SE-ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convs, an SE
    channel-attention gate, and a residual shortcut."""

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super(Block, self).__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)
        # 1x1 reduce (carries the stride).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        )
        # 3x3 at reduced width.
        self.conv2 = nn.Sequential(
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        )
        # 1x1 expand; ReLU deferred until after the residual add.
        self.conv3 = nn.Sequential(
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        )
        if is_1x1conv:
            # Projection shortcut: match channels/stride of the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            )
        # SE gate: squeeze to 1x1, reduce 16x, expand back, sigmoid weights.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(filter3, filter3 // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(filter3 // 16, filter3, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_shortcut = x
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x2 = self.se(x1)   # per-channel attention weights
        x1 = x1 * x2       # recalibrate features
        if self.is_1x1conv:
            x_shortcut = self.shortcut(x_shortcut)
        x1 = x1 + x_shortcut
        x1 = self.relu(x1)
        return x1


class senet(nn.Module):
    """SE-ResNet-50 style network built from Block bottlenecks.

    cfg keys: 'num' — blocks per stage (4-tuple), 'classes' — output classes.
    """

    def __init__(self, cfg):
        super(senet, self).__init__()
        classes = cfg['classes']
        num = cfg['num']
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (input /4).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = self._make_layer(64, (64, 64, 256), num[0], 1)
        self.conv3 = self._make_layer(256, (128, 128, 512), num[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), num[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), num[3], 2)
        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(2048, classes))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _make_layer(self, in_channels, filters, num, stride=1):
        """Stack `num` Blocks; only the first uses a projection shortcut."""
        layers = [Block(in_channels, filters, stride=stride, is_1x1conv=True)]
        for _ in range(1, num):
            layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))
        return nn.Sequential(*layers)


def Senet():
    """Build an SE-ResNet-50 for 10 classes."""
    cfg = {
        'num': (3, 4, 6, 3),  # blocks per stage — the ResNet-50 layout
        'classes': 10,        # was `(10)`: redundant parentheses, same int
    }
    return senet(cfg)


if __name__ == '__main__':
    # BUG FIX: the demo previously ran at import time; guard it so importing
    # this module does not execute a full forward pass.
    net = Senet()
    x = torch.rand((10, 3, 224, 224))
    for name, layer in net.named_children():
        if name != "fc":
            x = layer(x)
            print(name, 'output shape:', x.shape)  # fixed typo "shaoe"
        else:
            x = x.view(x.size(0), -1)
            x = layer(x)
            print(name, 'output shape:', x.shape)

训练结果展示

在这里插入图片描述

可直接运行的全部代码

import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# BUG FIX: the original listing imported torch / torch.nn three times each;
# deduplicated and grouped stdlib / third-party imports.


class Block(nn.Module):
    """SE-ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convs, an SE
    channel-attention gate, and a residual shortcut."""

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super(Block, self).__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        )
        if is_1x1conv:
            # Projection shortcut: match channels/stride of the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3),
            )
        # SE gate: squeeze to 1x1, reduce 16x, expand back, sigmoid weights.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(filter3, filter3 // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(filter3 // 16, filter3, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_shortcut = x
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x2 = self.se(x1)   # per-channel attention weights
        x1 = x1 * x2       # recalibrate features
        if self.is_1x1conv:
            x_shortcut = self.shortcut(x_shortcut)
        x1 = x1 + x_shortcut
        x1 = self.relu(x1)
        return x1


class senet(nn.Module):
    """SE-ResNet-50 style network built from Block bottlenecks."""

    def __init__(self, cfg):
        super(senet, self).__init__()
        classes = cfg['classes']
        num = cfg['num']
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = self._make_layer(64, (64, 64, 256), num[0], 1)
        self.conv3 = self._make_layer(256, (128, 128, 512), num[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), num[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), num[3], 2)
        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(2048, classes))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _make_layer(self, in_channels, filters, num, stride=1):
        """Stack `num` Blocks; only the first uses a projection shortcut."""
        layers = [Block(in_channels, filters, stride=stride, is_1x1conv=True)]
        for _ in range(1, num):
            layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))
        return nn.Sequential(*layers)


def Senet():
    """Build an SE-ResNet-50 for 10 classes."""
    cfg = {
        'num': (3, 4, 6, 3),
        'classes': 10,  # was `(10)`: redundant parentheses, same int
    }
    return senet(cfg)


def load_dataset(batch_size):
    """Return (train_iter, test_iter) DataLoaders over CIFAR-10.

    NOTE(review): downloads the dataset to data/cifar-10 on first use.
    """
    train_set = torchvision.datasets.CIFAR10(
        root="data/cifar-10", train=True,
        download=True, transform=transforms.ToTensor()
    )
    test_set = torchvision.datasets.CIFAR10(
        root="data/cifar-10", train=False,
        download=True, transform=transforms.ToTensor()
    )
    train_iter = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=4
    )
    test_iter = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=True, num_workers=4
    )
    return train_iter, test_iter


def train(net, train_iter, criterion, optimizer, num_epochs, device, num_print,
          lr_scheduler=None, test_iter=None):
    """Train `net`; returns (train accuracy history, test accuracy history)."""
    net.train()
    record_train = list()
    record_test = list()
    for epoch in range(num_epochs):
        print("========== epoch: [{}/{}] ==========".format(epoch + 1, num_epochs))
        total, correct, train_loss = 0, 0, 0
        start = time.time()
        for i, (X, y) in enumerate(train_iter):
            X, y = X.to(device), y.to(device)
            output = net(X)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
            train_acc = 100.0 * correct / total
            if (i + 1) % num_print == 0:
                print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}"
                      .format(i + 1, len(train_iter), train_loss / (i + 1),
                              train_acc, get_cur_lr(optimizer)))
        if lr_scheduler is not None:
            lr_scheduler.step()
        print("--- cost time: {:.4f}s ---".format(time.time() - start))
        if test_iter is not None:
            record_test.append(test(net, test_iter, criterion, device))
        record_train.append(train_acc)
    return record_train, record_test


def test(net, test_iter, criterion, device):
    """Evaluate `net` on `test_iter`; returns accuracy in percent."""
    total, correct, test_loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        print("* test *")
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            output = net(X)
            loss = criterion(output, y)
            # BUG FIX: the original printed only the LAST batch's loss as
            # "test_loss"; accumulate and report the mean over all batches.
            test_loss += loss.item()
            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
    test_acc = 100.0 * correct / total
    print("test_loss: {:.3f} | test_acc: {:6.3f}%"
          .format(test_loss / len(test_iter), test_acc))
    print("\n")
    net.train()
    return test_acc


def get_cur_lr(optimizer):
    """Return the learning rate of the optimizer's first param group."""
    for param_group in optimizer.param_groups:
        return param_group['lr']


def learning_curve(record_train, record_test=None):
    """Plot train (and optional test) accuracy per epoch."""
    plt.style.use("ggplot")
    plt.plot(range(1, len(record_train) + 1), record_train, label="train acc")
    if record_test is not None:
        plt.plot(range(1, len(record_test) + 1), record_test, label="test acc")
    plt.legend(loc=4)
    plt.title("learning curve")
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.show()


# Hyperparameters.
BATCH_SIZE = 128
NUM_EPOCHS = 12
NUM_CLASSES = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
NUM_PRINT = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    # BUG FIX: was `net = net = Senet()` (redundant double assignment).
    net = Senet()
    net = net.to(DEVICE)
    train_iter, test_iter = load_dataset(BATCH_SIZE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY, nesterov=True
    )
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    record_train, record_test = train(net, train_iter, criterion, optimizer,
                                      NUM_EPOCHS, DEVICE, NUM_PRINT,
                                      lr_scheduler, test_iter)
    learning_curve(record_train, record_test)


if __name__ == '__main__':
    main()

本文写的并不全面,se结构有好几种,大家可以看论文类比写出其他的结构

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/228543.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月16日 下午6:47
下一篇 2026年3月16日 下午6:47


相关推荐

  • Spring Cloud的架构[通俗易懂]

    Spring Cloud的架构[通俗易懂]SpringCloud架构图Eureka用于服务注册和发现,利用了客户端的服务发现,所以它内部需要Ribbon作为客户端负载均衡。Hystrix,客户端容错保护,服务熔断、请求缓存、请求合并、依赖隔离。Feign,声明式服务调用。Bus,消息总线,配合Config仓库修改的一种Stream实现,Dashboard,Hystrix仪表盘,监控集群模式和单点模式,其中集群模式…

    2022年5月13日
    40
  • 2018年SCI论文–整合GEO数据挖掘完整复现 七 :DAVID在线工具进行KEGG富集分析

    文章目录论文地址DAVID官网获得KEGG富集分析结果气泡图cytoscape软件绘制代谢通路网络图networkdatatabledata论文地址DAVID官网KEGG富集分析和GO富集分析方法一致,具体步骤见我上篇文章DAVID在线工具进行GO富集分析,这里主要展示可视化结果获得KEGG富集分析结果1.输入文件为所有差异表达基因列表2.选择GO富集分析结果时,我们点击“Path…

    2022年4月6日
    189
  • php mcrypt 遇到的错误问题

    php mcrypt 遇到的错误问题

    2021年9月7日
    47
  • OpenCode + GLM 安装和配置教程

    OpenCode + GLM 安装和配置教程

    2026年3月12日
    3
  • JavaScript对象详解,js对象属性的添加

    JavaScript对象详解,js对象属性的添加英文名 object 翻译成中文就是对象 用英语的角度来说 object 就是物体实体 即使他看不见摸不着 中文的对象指的是女朋友 在计算机中 用英语的角度理解对象 就是说放在内存里面的复杂数据集合 也叫做数据与方法的封装 是一种编程逻辑概念 函数是对数据与代码的封装 假如再把函数及函数外的数据进行封装 那就是 object 即对象

    2026年3月18日
    3
  • UNIX的常用命令「建议收藏」

    UNIX的常用命令「建议收藏」Unix常用命令介绍:  多命令行:“;”多行命令:“\”1、系统关闭reboot、halt/shutdown、poweroff2、passwd命令:修改系统用户密码passwd[username]3、su命令:切换系统用户su[-username]username为空表示root用户4、cat命令:将指定的文件在标准输出到显示器cat [-AbET] [文件名列表]-A      …

    2022年5月31日
    97

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号