libtorch-resnet18

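Here I'd like to share some lessons and examples from learning to build neural networks with libtorch, written down for your reference.
To save ourselves unnecessary trouble, we'll build the network by following the PyTorch version of ResNet. On to the code:
1. First, the PyTorch residual block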

import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        # Main branch: 3x3 conv (may downsample) -> BN -> ReLU -> 3x3 conv -> BN
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1),
            nn.BatchNorm2d(outchannel)
        )
        # Shortcut branch; None means identity, which requires left(x) to have the same shape as x
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)
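`out += residual` only works if `self.left(x)` has exactly the same shape as `x` (or as `self.right(x)` when a shortcut is given). The sketch below is plain C++ using only the standard library; the names `Shape`, `conv_out`, `left_branch_shape` and `identity_shortcut_ok` are hypothetical helpers, not PyTorch APIs. It applies the usual convolution output-size formula to the two 3x3 convolutions of `self.left` to check when the identity shortcut (`shortcut=None`) is shape-compatible.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical helper types/functions for illustration; not part of PyTorch.
using Shape = std::array<int64_t, 4>;  // {N, C, H, W}

// Output length of a convolution along one spatial dimension (dilation = 1).
int64_t conv_out(int64_t in, int64_t kernel, int64_t stride, int64_t padding) {
    assert(stride > 0 && in + 2 * padding >= kernel);
    return (in + 2 * padding - kernel) / stride + 1;
}

// Shape produced by self.left: Conv3x3(stride, pad 1) -> BN -> ReLU -> Conv3x3(1, pad 1) -> BN.
// stride_h / stride_w allow tuple strides such as (2, 1).
Shape left_branch_shape(const Shape& x, int64_t outchannel, int64_t stride_h, int64_t stride_w) {
    int64_t h = conv_out(x[2], 3, stride_h, 1);
    int64_t w = conv_out(x[3], 3, stride_w, 1);
    // The second 3x3 convolution has stride 1 and padding 1, so H and W are kept.
    h = conv_out(h, 3, 1, 1);
    w = conv_out(w, 3, 1, 1);
    return {x[0], outchannel, h, w};
}

// True if `out += x` (identity shortcut, shortcut=None) would be shape-compatible.
bool identity_shortcut_ok(const Shape& x, int64_t outchannel, int64_t stride_h, int64_t stride_w) {
    return left_branch_shape(x, outchannel, stride_h, stride_w) == x;
}
```

For an input of shape {2, 64, 8, 70}, `identity_shortcut_ok(x, 64, 1, 1)` is true, while changing the channel count to 128 or the stride to (2, 1) makes it false; those are exactly the cases where a `shortcut` module has to be passed in.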

2. libtorch residual block
Since we're building it in C++, we start with the header file.
2.1 Residual block header (declaration)

#pragma once
#include <torch/torch.h>

// Helper that builds Conv2dOptions in one call, so we don't repeat the boilerplate
inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kernel_size,
    int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_bias = true) {
    torch::nn::Conv2dOptions conv_options = torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size);
    conv_options.stride(stride);
    conv_options.padding(padding);
    conv_options.bias(with_bias);
    conv_options.groups(groups);
    return conv_options;
}

// Residual block declaration
class Block_ocrImpl : public torch::nn::Module {
public:
    Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = nullptr, int groups = 1, int base_width = 64, bool is_basic = true);
    torch::Tensor forward(torch::Tensor x);
    torch::nn::Sequential downsample{ nullptr };
private:
    bool is_basic = true;
    int64_t stride = 1;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::BatchNorm2d bn2{ nullptr };
    // conv3/bn3 are only used by the bottleneck variant (is_basic == false)
    torch::nn::Conv2d conv3{ nullptr };
    torch::nn::BatchNorm2d bn3{ nullptr };
};
TORCH_MODULE(Block_ocr);
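`conv_options` only bundles the kernel size, stride, padding, groups and bias of a `Conv2d`. How those settings change the spatial size follows the output-size formula from the PyTorch `Conv2d` documentation: out = floor((in + 2 * padding - kernel_size) / stride) + 1 for dilation 1, the default here. A minimal standard-library sketch, with a hypothetical function name `conv2d_out`:

```cpp
#include <cassert>
#include <cstdint>

// Spatial output size of a 2-D convolution along one dimension, dilation = 1:
//   out = floor((in + 2 * padding - kernel_size) / stride) + 1
// Hypothetical helper for illustration; libtorch does not expose a function with this name.
int64_t conv2d_out(int64_t in, int64_t kernel_size, int64_t stride = 1, int64_t padding = 0) {
    assert(kernel_size > 0 && stride > 0 && padding >= 0);
    assert(in + 2 * padding >= kernel_size);  // otherwise the input is too small
    // The numerator is non-negative, so integer division is the floor.
    return (in + 2 * padding - kernel_size) / stride + 1;
}
```

With the argument combinations used in this post: `conv2d_out(56, 3, 1, 1)` gives 56 (size kept), `conv2d_out(56, 3, 2, 1)` gives 28 (halved), and the 1x1 stride-2 downsample convolution `conv2d_out(56, 1, 2, 0)` also gives 28, so the shortcut and the main branch line up. `groups` and `with_bias` do not affect the output size.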

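2.2 Residual block definition
The convolutions below use the `conv_options` helper we wrote in the header, so we don't have to repeat the Conv2dOptions setup every time; I put it in the header from 2.1.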

// Residual block definition
Block_ocrImpl::Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    // Channel count of the inner convolutions; equals planes when groups == 1 and base_width == 64
    int width = int(planes * (base_width / 64.)) * groups;

    // Basic block: two 3x3 convolutions, the first one carries the stride
    conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    is_basic = _is_basic;
    if (!is_basic) {
        // Bottleneck: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand to planes * 4
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    // downsample may be a null holder (the default argument) or an empty Sequential:
    // downsample.is_empty() checks the holder, downsample->is_empty() checks for layers
    if (!downsample.is_empty() && !downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}

// Residual block forward pass
torch::Tensor Block_ocrImpl::forward(torch::Tensor x) {
    // Keep a handle to the input; no clone is needed because x is reassigned, not modified in place
    torch::Tensor residual = x;

    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);

    x = conv2->forward(x);
    x = bn2->forward(x);

    if (!is_basic) {
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }

    // Project the shortcut when the shape changes (stride != 1 or a different channel count)
    if (!downsample.is_empty() && !downsample->is_empty()) {
        residual = downsample->forward(residual);
    }

    x += residual;
    x = torch::relu(x);

    return x;
}
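The constructor's channel bookkeeping can be checked without libtorch. The sketch below (the struct `BlockChannels` and the function `block_channels` are hypothetical names) mirrors the two formulas in the code: the inner width is int(planes * (base_width / 64.)) * groups, and the block returns `width` channels in the basic variant (conv2/bn2) but `planes * 4` in the bottleneck variant (conv3/bn3).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical summary of the channel counts inside Block_ocrImpl.
struct BlockChannels {
    int64_t width;  // channels of the inner convolutions
    int64_t out;    // channels of the tensor returned by forward()
};

// Mirrors the arithmetic in the Block_ocrImpl constructor:
//   width = int(planes * (base_width / 64.)) * groups
//   out   = width       (basic block: conv1 3x3 -> conv2 3x3)
//         = planes * 4  (bottleneck: conv1 1x1 -> conv2 3x3 -> conv3 1x1)
BlockChannels block_channels(int64_t planes, int groups, int base_width, bool is_basic) {
    assert(groups >= 1 && base_width >= 1);
    // static_cast truncates toward zero, like int(...) in the original code
    int64_t width = static_cast<int64_t>(planes * (base_width / 64.0)) * groups;
    int64_t out = is_basic ? width : planes * 4;
    return {width, out};
}
```

In the basic variant the output has `width` channels, which matches `planes` (and therefore the downsample branch built by `_make_layer`) only when groups == 1 and base_width == 64; torchvision's `BasicBlock` raises an error for other values. The bottleneck variant always returns planes * 4, e.g. 256 channels for planes = 64, whether width is 64 (ResNet-50), 128 with groups = 32 and base_width = 4 (ResNeXt-50 32x4d), or 128 with base_width = 128 (Wide ResNet-50-2).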

3. PyTorch ResNet main network

# Relies on the imports and ResidualBlock from section 1
class ResNet18(nn.Module):
    def __init__(self, nc):  # nc: number of input channels
        super(ResNet18, self).__init__()
        # Input stem: a 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool
        self.pre = nn.Sequential(
            nn.Conv2d(nc, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Middle stages: stacked 3x3 convolutions extract features. The third argument of
        # _make_layer (block_num) controls how many blocks are stacked; stride=(2, 1)
        # halves the height but keeps the width
        self.layer1 = self._make_layer(64, 128, 1)

        self.layer2 = self._make_layer(128, 256, 2, stride=(2, 1))

        self.layer3 = self._make_layer(256, 512, 5, stride=(2, 1))

        self.layer4 = self._make_layer(512, 512, 3, stride=(2, 1))

    def _make_layer(self, inchannel, outchannel, block_num, stride=(1, 1)):
        # 1x1 projection shortcut for the first block of the stage
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride),
            nn.BatchNorm2d(outchannel)
        )
        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))  # changes the channel count
        # range(1, block_num + 1) appends block_num more blocks, so the stage has block_num + 1 blocks
        for i in range(1, block_num + 1):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    # Define how data flows through the network
    def forward(self, x):
        x = self.pre(x)     # [2,3,32,280] ---> [2,64,8,70]
        x = self.layer1(x)  # [2,128,8,70]
        x = self.layer2(x)  # [2,256,4,70]
        x = self.layer3(x)  # [2,512,2,70]
        x = self.layer4(x)  # [2,512,1,70]
        return x
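Because layer2 to layer4 use stride=(2, 1), each of them halves the height but keeps the width, so a wide input stays wide. The standard-library C++ sketch below only models shapes (the names `Shape`, `out_len`, `stem`, `stage` and `trace` are hypothetical); it follows the stem and the first block of each stage, since the remaining blocks use stride 1 with padding 1 and leave H and W unchanged.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::array<int64_t, 4>;  // {N, C, H, W}

// One-dimensional conv/pool output size (dilation 1, floor mode).
int64_t out_len(int64_t in, int64_t kernel, int64_t stride, int64_t padding) {
    assert(stride > 0 && in + 2 * padding >= kernel);
    return (in + 2 * padding - kernel) / stride + 1;
}

// Stem (self.pre): Conv2d(nc, 64, 7, 2, 3) followed by MaxPool2d(3, 2, 1).
Shape stem(const Shape& x) {
    int64_t h = out_len(out_len(x[2], 7, 2, 3), 3, 2, 1);
    int64_t w = out_len(out_len(x[3], 7, 2, 3), 3, 2, 1);
    return {x[0], 64, h, w};
}

// A stage built by _make_layer(inchannel, outchannel, n, stride): only the first
// block's 3x3 conv (padding 1) and its 1x1 shortcut change H and W.
Shape stage(const Shape& x, int64_t outchannel, int64_t stride_h, int64_t stride_w) {
    return {x[0], outchannel, out_len(x[2], 3, stride_h, 1), out_len(x[3], 3, stride_w, 1)};
}

// Shapes after pre, layer1, layer2, layer3 and layer4 of the ResNet18 class above.
std::vector<Shape> trace(const Shape& input) {
    std::vector<Shape> shapes;
    shapes.push_back(stem(input));
    shapes.push_back(stage(shapes.back(), 128, 1, 1));  // layer1
    shapes.push_back(stage(shapes.back(), 256, 2, 1));  // layer2
    shapes.push_back(stage(shapes.back(), 512, 2, 1));  // layer3
    shapes.push_back(stage(shapes.back(), 512, 2, 1));  // layer4
    return shapes;
}
```

Calling `trace({2, 3, 32, 280})` gives [2, 64, 8, 70] after `pre`, then [2, 128, 8, 70], [2, 256, 4, 70], [2, 512, 2, 70] and finally [2, 512, 1, 70]: the height collapses to 1 while the 70 columns are preserved.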

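4. libtorch main network
Like the residual block, it is split into a header (.h) and a source file (.cpp).
We write the header first, again modeled on the PyTorch version, which saves us a lot of trouble.
4.1 Main network header (declaration)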

// Main network declaration (needs conv_options and Block_ocr from the 2.1 header)
class ResNet_ocrImpl : public torch::nn::Module {
public:
    ResNet_ocrImpl(/*std::vector<int> layers, int num_classes = 1000,*/ std::string model_type = "resnet18",
        int groups = 1, int width_per_group = 64);
    torch::Tensor forward(torch::Tensor x);
    // Declared but not defined in this post; provide a definition before calling it
    std::vector<torch::Tensor> features(torch::Tensor x);
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);
private:
    int expansion = 1; bool is_basic = true;
    int64_t inplanes = 64; int groups = 1; int base_width = 64;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Sequential layer1{ nullptr };
    torch::nn::Sequential layer2{ nullptr };
    torch::nn::Sequential layer3{ nullptr };
    torch::nn::Sequential layer4{ nullptr };
};
TORCH_MODULE(ResNet_ocr);
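The constructor's commented-out `std::vector<int> layers` parameter suggests the per-stage block counts were meant to be passed in rather than hard-coded. For reference, the standard torchvision ResNets use {2, 2, 2, 2} (resnet18) and {3, 4, 6, 3} (resnet34) with basic blocks, and {3, 4, 6, 3} (resnet50), {3, 4, 23, 3} (resnet101) and {3, 8, 36, 3} (resnet152) with bottleneck blocks. The lookup below is a hypothetical helper (`ResNetConfig` and `resnet_config` are not libtorch APIs) showing the values `layers`, `expansion` and `is_basic` would take for each `model_type`.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical configuration record; not part of libtorch.
struct ResNetConfig {
    std::vector<int> layers;  // number of blocks in layer1..layer4
    int expansion;            // 1 for basic blocks, 4 for bottleneck blocks
    bool is_basic;
};

// Per-stage block counts of the standard torchvision ResNets.
ResNetConfig resnet_config(const std::string& model_type) {
    static const std::map<std::string, ResNetConfig> table = {
        {"resnet18", {{2, 2, 2, 2}, 1, true}},
        {"resnet34", {{3, 4, 6, 3}, 1, true}},
        {"resnet50", {{3, 4, 6, 3}, 4, false}},
        {"resnet101", {{3, 4, 23, 3}, 4, false}},
        {"resnet152", {{3, 8, 36, 3}, 4, false}},
    };
    auto it = table.find(model_type);
    if (it == table.end()) {
        throw std::invalid_argument("unknown model_type: " + model_type);
    }
    return it->second;
}
```

The number in the name counts the weighted layers: resnet18 has 8 basic blocks x 2 convolutions + the stem convolution + the final fc layer = 18, and resnet50 has 16 bottlenecks x 3 + 2 = 50. Wiring this in would mean calling `_make_layer(64, cfg.layers[0])` and so on, where `cfg` is the returned config; as written, the constructor only switches the block type and always builds 2 blocks per stage.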

4.2 Main network definition

// First define the stage builder _make_layer, again written following PyTorch
torch::nn::Sequential ResNet_ocrImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    // A default-constructed Sequential holds an empty module list, which Block_ocr treats as "no downsampling"
    torch::nn::Sequential downsample;
    // The first block needs a projection shortcut when the spatial size or channel count changes
    if (stride != 1 || inplanes != planes * expansion) {
        downsample = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(planes * expansion)
        );
    }
    torch::nn::Sequential layers;
    layers->push_back(Block_ocr(inplanes, planes, stride, downsample, groups, base_width, is_basic));
    inplanes = planes * expansion;
    // The remaining blocks keep the shape, so they use stride 1 and no downsampling
    for (int64_t i = 1; i < blocks; i++) {
        layers->push_back(Block_ocr(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }

    return layers;
}
// Then define the main network constructor
ResNet_ocrImpl::ResNet_ocrImpl(/*std::vector<int> layers, int num_classes,*/ std::string model_type, int _groups, int _width_per_group)
{
    // resnet18/resnet34 use basic blocks; any other model_type switches to bottleneck blocks.
    // The block counts below are hard-coded to 2 per stage (the resnet18 layout).
    if (model_type != "resnet18" && model_type != "resnet34")
    {
        expansion = 4;
        is_basic = false;
    }
    groups = _groups;
    base_width = _width_per_group;
    // Stem: 7x7 stride-2 convolution on a 1-channel (grayscale) input
    conv1 = torch::nn::Conv2d(conv_options(1, 64, 7, 2, 3, 1, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    layer1 = _make_layer(64, 2/*layers[0]*/);
    layer2 = _make_layer(128, 2/*layers[1]*/, 2);
    layer3 = _make_layer(256, 2/*layers[2]*/, 2);
    layer4 = _make_layer(512, 2/*layers[3]*/, 2);
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
    // Weight initialization: Kaiming normal (fan_out, ReLU) for convolutions,
    // weight 1 and bias 0 for BatchNorm layers
    for (auto& module : modules(/*include_self=*/false)) {
        if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
            torch::nn::init::kaiming_normal_(
                M->weight,
                /*a=*/0,
                torch::kFanOut,
                torch::kReLU);
        }
        else if (auto B = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
            torch::nn::init::constant_(B->weight, 1);
            torch::nn::init::constant_(B->bias, 0);
        }
    }
}

// ResNet forward pass
torch::Tensor ResNet_ocrImpl::forward(torch::Tensor x) {
    // Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max pool
    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    // Four residual stages
    x = layer1->forward(x);
    x = layer2->forward(x);
    x = layer3->forward(x);
    x = layer4->forward(x);
    return x;
}
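One way to sanity-check the assembled network is to count its trainable parameters. The standard-library sketch below (the names `conv_params`, `bn_params`, `make_layer_params` and `count_params` are hypothetical) replays `_make_layer` for the resnet18 configuration with the default groups = 1 and base_width = 64: a downsample branch is added when stride != 1 or inplanes != planes * expansion, and `inplanes` is updated after the first block. All convolutions are bias-free, and each BatchNorm2d contributes a weight and a bias per channel (its running statistics are buffers, not parameters).

```cpp
#include <cassert>
#include <cstdint>

// Parameters of a bias-free k x k convolution (groups = 1) and of a BatchNorm2d.
int64_t conv_params(int64_t in, int64_t out, int64_t k) { return in * out * k * k; }
int64_t bn_params(int64_t channels) { return 2 * channels; }  // weight + bias

// Replays ResNet_ocrImpl::_make_layer for basic blocks (expansion = 1,
// groups = 1, base_width = 64) and returns the parameter count of one stage.
// `inplanes` is updated in place, like the member variable.
int64_t make_layer_params(int64_t& inplanes, int64_t planes, int64_t blocks, int64_t stride) {
    const int64_t expansion = 1;
    int64_t total = 0;
    if (stride != 1 || inplanes != planes * expansion) {
        // downsample: 1x1 conv + BatchNorm2d
        total += conv_params(inplanes, planes * expansion, 1) + bn_params(planes * expansion);
    }
    for (int64_t i = 0; i < blocks; ++i) {
        int64_t in = (i == 0) ? inplanes : planes * expansion;
        // basic block: conv1 3x3 + bn1 + conv2 3x3 + bn2 (width == planes here)
        total += conv_params(in, planes, 3) + bn_params(planes);
        total += conv_params(planes, planes, 3) + bn_params(planes);
    }
    inplanes = planes * expansion;
    return total;
}

// Trainable parameters of ResNet_ocr("resnet18") with `in_channels` input channels.
int64_t count_params(int64_t in_channels) {
    int64_t total = conv_params(in_channels, 64, 7) + bn_params(64);  // conv1 + bn1
    int64_t inplanes = 64;
    total += make_layer_params(inplanes, 64, 2, 1);   // layer1
    total += make_layer_params(inplanes, 128, 2, 2);  // layer2
    total += make_layer_params(inplanes, 256, 2, 2);  // layer3
    total += make_layer_params(inplanes, 512, 2, 2);  // layer4
    return total;
}
```

`count_params(1)` returns 11,170,240 for the 1-channel model above, which you could compare with the sum of `p.numel()` over `model->parameters()` in libtorch. With 3 input channels it returns 11,176,512, i.e. torchvision's resnet18 total of 11,689,512 minus the 513,000 parameters of its fc layer (512 x 1000 weights + 1000 biases).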

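That's the libtorch version of the ResNet18 network, built entirely in C++. Because I need to concatenate this ResNet with another network, I removed the fc and softmax layers; add them back yourself if you need them. This implementation also follows the approach of an expert developer on GitHub.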
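If you do want the classifier back, torchvision's ResNet ends with adaptive average pooling to 1x1, a flatten and a `Linear(512 * expansion, num_classes)` layer; in libtorch the corresponding modules are `torch::nn::AdaptiveAvgPool2d` and `torch::nn::Linear`, with softmax applied on top (or left to the loss, since cross-entropy already includes log-softmax). The standard-library sketch below only illustrates the arithmetic of such a head for one image; `FeatureMap`, `global_avg_pool`, `linear` and `softmax` are hypothetical helpers, not libtorch functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Feature map for a single image, indexed as [channel][row][col].
using FeatureMap = std::vector<std::vector<std::vector<double>>>;

// Global average pooling: averages each channel over H x W -> vector of C values.
std::vector<double> global_avg_pool(const FeatureMap& x) {
    std::vector<double> pooled;
    for (const auto& channel : x) {
        double sum = 0.0;
        std::size_t count = 0;
        for (const auto& row : channel) {
            for (double v : row) { sum += v; ++count; }
        }
        pooled.push_back(count ? sum / count : 0.0);
    }
    return pooled;
}

// Fully connected layer: y = W x + b, with W stored as [out][in].
std::vector<double> linear(const std::vector<double>& x,
                           const std::vector<std::vector<double>>& W,
                           const std::vector<double>& b) {
    assert(W.size() == b.size());
    std::vector<double> y(b);
    for (std::size_t o = 0; o < W.size(); ++o) {
        assert(W[o].size() == x.size());
        for (std::size_t i = 0; i < x.size(); ++i) y[o] += W[o][i] * x[i];
    }
    return y;
}

// Numerically stable softmax: subtracts the largest logit before exponentiating.
std::vector<double> softmax(const std::vector<double>& logits) {
    assert(!logits.empty());
    double m = *std::max_element(logits.begin(), logits.end());
    std::vector<double> p;
    double z = 0.0;
    for (double v : logits) { p.push_back(std::exp(v - m)); z += p.back(); }
    for (double& v : p) v /= z;
    return p;
}
```

In a real head the pooled vector has 512 * expansion entries and the fc weight is num_classes x (512 * expansion); the tiny sizes in a quick test are only for illustration.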
Technology is innocent, knowledge is innocent: let's be spreaders of knowledge!

Publisher: 全栈程序员-站长. Please credit the source when reposting: https://javaforall.net/142669.html Original link: https://javaforall.net