使用PyTorch搭建ResNet34网络

使用PyTorch搭建ResNet34网络ResNet34 网络结构先上图参照 ResNet18 的搭建 由于 34 层和 18 层几乎相同 叠加卷积单元数即可 所以没有写注释 具体可以参考我的 ResNet18 搭建中的注释 ResNet34 的训练部分也可以参照 importtorchi nnasnnfromto nnimportfunc nn Module def init self in channel out chann

ResNet34网络结构

先上图
在这里插入图片描述

在这里插入图片描述

参照ResNet18的搭建,由于34层和18层几乎相同,叠加卷积单元数即可,所以没有写注释,具体可以参考我的ResNet18搭建中的注释,ResNet34的训练部分也可以参照。

使用PyTorch搭建ResNet18网络

ResNet34的model.py模型部分

import torch import torch.nn as nn from torch.nn import functional as F class CommonBlock(nn.Module): def __init__(self, in_channel, out_channel, stride): super(CommonBlock, self).__init__() self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channel) self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channel) def forward(self, x): identity = x x = F.relu(self.bn1(self.conv1(x)), inplace=True) x = self.bn2(self.conv2(x)) x += identity return F.relu(x, inplace=True) class SpecialBlock(nn.Module): def __init__(self, in_channel, out_channel, stride): super(SpecialBlock, self).__init__() self.change_channel = nn.Sequential( nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride[0], padding=0, bias=False), nn.BatchNorm2d(out_channel) ) self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride[0], padding=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channel) self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride[1], padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channel) def forward(self, x): identity = self.change_channel(x) x = F.relu(self.bn1(self.conv1(x)), inplace=True) x = self.bn2(self.conv2(x)) x += identity return F.relu(x, inplace=True) class ResNet34(nn.Module): def __init__(self, classes_num): super(ResNet34, self).__init__() self.prepare = nn.Sequential( nn.Conv2d(3, 64, 7, 2, 3), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(3, 2, 1) ) self.layer1 = nn.Sequential( CommonBlock(64, 64, 1), CommonBlock(64, 64, 1), CommonBlock(64, 64, 1) ) self.layer2 = nn.Sequential( SpecialBlock(64, 128, [2, 1]), CommonBlock(128, 128, 1), CommonBlock(128, 128, 1), CommonBlock(128, 128, 1) ) self.layer3 = nn.Sequential( SpecialBlock(128, 256, [2, 1]), CommonBlock(256, 256, 1), CommonBlock(256, 256, 1), CommonBlock(256, 256, 1), CommonBlock(256, 256, 1), CommonBlock(256, 256, 1) ) self.layer4 = nn.Sequential( SpecialBlock(256, 512, [2, 1]), CommonBlock(512, 512, 1), CommonBlock(512, 512, 1) ) self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) self.fc = nn.Sequential( nn.Dropout(p=0.5), nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Dropout(p=0.5), nn.Linear(256, classes_num) ) def forward(self, x): x = self.prepare(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.pool(x) x = x.reshape(x.shape[0], -1) x = self.fc(x) return x 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/233959.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • python encode和decode的区别_encode和decode的区别

    python encode和decode的区别_encode和decode的区别字符串在Python内部的表示是unicode编码,因此,在做编码转换时,通常需要以unicode作为中间编码,即先将其他编码的字符串解码(decode)成unicode,再从unicode编码(encode)成另一种编码。decode的作用是将其他编码的字符串转换成unicode编码,如str1.decode(‘gb2312’),表示将gb2312编码的字符串str1转换成unicode编码。e…

    2022年10月7日
    2
  • realsense深度图像保存方法

    realsense深度图像保存方法一般使用realsense时会保存视频序列,当保存深度图像时,需要注意保存的图像矩阵的格式,不然可能造成深度值的丢失。在众多图像库中,一般会使用opencv中的imwrite()函数进行深度图像的保存。一般深度图像中深度值的单位是mm,因此一般使用np.uint16作为最终数据格式保存。例子:importnumpyasnpimportcv2deffun1(…

    2022年4月25日
    31
  • Java 数组在内存中的存储 数组的常见操作

    Java 数组在内存中的存储 数组的常见操作Java 数组在内存中的存储数组的常见操作

    2025年6月29日
    2
  • leetcode516_leetcode46

    leetcode516_leetcode46Givenacollectionofnumbers,returnallpossiblepermutations.Forexample,[1,2,3] havethefollowingpermutations:[1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2],and [3,2,1].思路:递归咯c

    2022年9月20日
    2
  • python和c++哪个好_python地名识别

    python和c++哪个好_python地名识别Ctrl+N按文件名搜索py文件ctrl+n可以搜索py文件勾选上面这个框可以搜索工程以外的文件Ctrl+shift+N按文件名搜索所有类型的文件Ctrl+shift+N可以搜索py文件,也可以搜索其它类型的文件。除了搜索不同类型的文件,Ctrl+shift+N还有一个强大之处是可以搜索路径,只需要在你搜索的词前面或后面加上/ctrl+shift+f全局字符串搜索这种搜索的名字叫做”fin…

    2022年8月28日
    3
  • Navicat 15 for MySQL激活码【在线注册码/序列号/破解码】[通俗易懂]

    Navicat 15 for MySQL激活码【在线注册码/序列号/破解码】,https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月18日
    171

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号