ResNet34+Unet(可以直接用)

ResNet34+Unet(可以直接用)importtorchfromtorchimportnnimporttorch.nn.functionalasF#因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码classResidualBlock(nn.Module):def__init__(self,inchannel,outchannel,stride,shortcut=None):super(ResidualBlock,self).__init__(

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

import torch
from torch import nn
import torch.nn.functional as F


# 因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码
class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.basic = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1,
                      bias=False),  # 要采样的话在这里改变stride
            nn.BatchNorm2d(outchannel),  # 批处理正则化
            nn.ReLU(inplace=True),  # 激活
            nn.Conv2d(outchannel, outchannel, 3, 1, 1,
                      bias=False),  # 采样之后注意保持feature map的大小不变
            nn.BatchNorm2d(outchannel),
        )
        self.shortcut = shortcut

    def forward(self, x):
        out = self.basic(x)
        residual = x if self.shortcut is None else self.shortcut(x)  # 计算残差
        out += residual
        return nn.ReLU(inplace=True)(out)  # 注意激活


class Conv2dReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
    ):
        super(Conv2dReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return x


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)

        x = self.conv1(x)
        x = self.conv2(x)

        return x


class SegmentationHead(nn.Sequential):
    def __init__(self,
                 in_channels=16,
                 out_channels=1,
                 kernel_size=3,
                 upsampling=1):
        conv2d = nn.Conv2d(in_channels,
                           out_channels,
                           kernel_size=kernel_size,
                           padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(
            scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


# ResNet类
class Resnet34(nn.Module):
    def __init__(self, inchannels):
        super(Resnet34, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(inchannels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )  # 开始的部分
        self.body = self.makelayers([3, 4, 6, 3])  # 具有重复模块的部分
        in_channels = [512, 256, 128, 128, 32]
        skip_channels = [256, 128, 64, 0, 0]
        out_channels = [256, 128, 64, 32, 16]
        blocks = [
            DecoderBlock(in_ch, skip_ch,
                         out_ch) for in_ch, skip_ch, out_ch in zip(
                             in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.seg = SegmentationHead()

    def makelayers(self, blocklist):  # 注意传入列表而不是解列表
        self.layers = []
        for index, blocknum in enumerate(blocklist):
            if index != 0:
                shortcut = nn.Sequential(
                    nn.Conv2d(64 * 2**(index - 1),
                              64 * 2**index,
                              1,
                              2,
                              bias=False),
                    nn.BatchNorm2d(64 * 2**index))  # 使得输入输出通道数调整为一致
                self.layers.append(
                    ResidualBlock(64 * 2**(index - 1), 64 * 2**index, 2,
                                  shortcut))  # 每次变化通道数时进行下采样
            for i in range(0 if index == 0 else 1, blocknum):
                self.layers.append(
                    ResidualBlock(64 * 2**index, 64 * 2**index, 1))
        return nn.Sequential(*self.layers)

    def forward(self, x):
        self.features = []
        # 下采样
        # x = self.pre(x)
        for i, l in enumerate(self.pre):
            x = l(x)
            if i == 2:
                self.features.append(x)

        print("y=", len(self.features))
        for i, l in enumerate(self.body):
            if i == 3 or i == 7 or i == 13:
                self.features.append(x)
            x = l(x)
        skips = self.features[::-1]

        # skips = self.features[1:]

        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        x = self.seg(x)
        return x



四次Skipconnect分别在:Maxpool前;另外三次在通道数变化前。
上采样combine时采用的是插值(nn.functionnal.interpolate)。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185682.html原文链接:https://javaforall.net

(0)
上一篇 2022年10月5日 下午2:00
下一篇 2022年10月5日 下午2:00


相关推荐

  • 堆和栈的概念和区别

    堆和栈的概念和区别在说堆和栈之前 我们先说一下 JVM 虚拟机 内存的划分 nbsp nbsp nbsp nbsp nbsp Java 程序在运行时都要开辟空间 任何软件在运行时都要在内存中开辟空间 Java 虚拟机运行时也是要开辟空间的 JVM 运行时在内存中开辟一片内存区域 启动时在自己的内存区域中进行更细致的划分 因为虚拟机中每一片内存处理的方式都不同 所以要单独进行管理 nbsp nbsp nbsp nbsp nbsp JVM 内存的划分有五片 nbsp nbsp nbsp nbsp nbsp 1 nbsp nbsp nbsp 寄存器

    2026年3月19日
    1
  • pycharm2018激活码 pycharm激活码_软件一键无痕看怎么使用的

    pycharm2018激活码 pycharm激活码_软件一键无痕看怎么使用的PyCharm激活码最新破解教程,Mac版激活至2299年,PyCharm激活码2021.3.3

    2022年4月20日
    159
  • 消息队列 rabbitmq面试题(中间件面试题)

    文章目录为什么使用MQ?MQ的优点消息队列有什么优缺点?RabbitMQ有什么优缺点?你们公司生产环境用的是什么消息中间件?Kafka、ActiveMQ、RabbitMQ、RocketMQ有什么优缺点?MQ有哪些常见问题?如何解决这些问题?什么是RabbitMQ?rabbitmq的使用场景RabbitMQ基本概念RabbitMQ的工作模式如何保证RabbitMQ消息的顺序性?消息如何分发?消…

    2022年4月14日
    62
  • pycharm如何安装python环境_pycharm怎么安装「建议收藏」

    pycharm如何安装python环境_pycharm怎么安装「建议收藏」安装方法:1、安装配置好Python环境;2、从官网下载pycharm安装程序;3、直接双击下载好的exe文件,进入安装向导界面,按照指示一步步操作;4、点击Install进行安装,等待安装完成后,点击Finish结束安装即可。本教程操作环境:windows7系统、Python3.5.2版本、DellG3电脑。首先我们来安装python1、首先进入网站下载:点击打开链接(或自己输入网址http…

    2022年8月27日
    12
  • 来客在线客服系统源码 支持一键安装

    来客在线客服系统源码 支持一键安装简介:来客客服源码/带完整文字教程/一键安装好友分享的,状态比流通版本还是要好很多。不支持前端商户注册,这个大家应该也都用过了,有喜欢的,自己拿去吧,东西如下。网盘下载地址:http://kekewl.org/D11LdBKXP7L0图片:…

    2022年7月19日
    44
  • 2021年Vue最常见的面试题以及答案(面试必过)[通俗易懂]

    2021年Vue最常见的面试题以及答案(面试必过)[通俗易懂]这里写目录标题对MVVM的理解?Vue数据双向绑定原理Vue的响应式原理vue中组件的data为什么是一个函数vue中created与mounted区别Vue中computed与method的区别Vue中watch用法详解Vue中常用的一些指令说说vue的生命周期对MVVM的理解?MVVM由Model、View、ViewModel三部分构成,Model层代表数据模型,也可以在Model中定义数据修改和操作的业务逻辑;View代表UI组件,它负责将数据模型转化成UI展现出来;ViewMode

    2022年5月31日
    112

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号