ResNet34+Unet(可以直接用)

ResNet34+Unet(可以直接用)importtorchfromtorchimportnnimporttorch.nn.functionalasF#因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码classResidualBlock(nn.Module):def__init__(self,inchannel,outchannel,stride,shortcut=None):super(ResidualBlock,self).__init__(

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

import torch
from torch import nn
import torch.nn.functional as F


# 因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码
class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.basic = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1,
                      bias=False),  # 要采样的话在这里改变stride
            nn.BatchNorm2d(outchannel),  # 批处理正则化
            nn.ReLU(inplace=True),  # 激活
            nn.Conv2d(outchannel, outchannel, 3, 1, 1,
                      bias=False),  # 采样之后注意保持feature map的大小不变
            nn.BatchNorm2d(outchannel),
        )
        self.shortcut = shortcut

    def forward(self, x):
        out = self.basic(x)
        residual = x if self.shortcut is None else self.shortcut(x)  # 计算残差
        out += residual
        return nn.ReLU(inplace=True)(out)  # 注意激活


class Conv2dReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
    ):
        super(Conv2dReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return x


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)

        x = self.conv1(x)
        x = self.conv2(x)

        return x


class SegmentationHead(nn.Sequential):
    def __init__(self,
                 in_channels=16,
                 out_channels=1,
                 kernel_size=3,
                 upsampling=1):
        conv2d = nn.Conv2d(in_channels,
                           out_channels,
                           kernel_size=kernel_size,
                           padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(
            scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


# ResNet类
class Resnet34(nn.Module):
    def __init__(self, inchannels):
        super(Resnet34, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(inchannels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )  # 开始的部分
        self.body = self.makelayers([3, 4, 6, 3])  # 具有重复模块的部分
        in_channels = [512, 256, 128, 128, 32]
        skip_channels = [256, 128, 64, 0, 0]
        out_channels = [256, 128, 64, 32, 16]
        blocks = [
            DecoderBlock(in_ch, skip_ch,
                         out_ch) for in_ch, skip_ch, out_ch in zip(
                             in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.seg = SegmentationHead()

    def makelayers(self, blocklist):  # 注意传入列表而不是解列表
        self.layers = []
        for index, blocknum in enumerate(blocklist):
            if index != 0:
                shortcut = nn.Sequential(
                    nn.Conv2d(64 * 2**(index - 1),
                              64 * 2**index,
                              1,
                              2,
                              bias=False),
                    nn.BatchNorm2d(64 * 2**index))  # 使得输入输出通道数调整为一致
                self.layers.append(
                    ResidualBlock(64 * 2**(index - 1), 64 * 2**index, 2,
                                  shortcut))  # 每次变化通道数时进行下采样
            for i in range(0 if index == 0 else 1, blocknum):
                self.layers.append(
                    ResidualBlock(64 * 2**index, 64 * 2**index, 1))
        return nn.Sequential(*self.layers)

    def forward(self, x):
        self.features = []
        # 下采样
        # x = self.pre(x)
        for i, l in enumerate(self.pre):
            x = l(x)
            if i == 2:
                self.features.append(x)

        print("y=", len(self.features))
        for i, l in enumerate(self.body):
            if i == 3 or i == 7 or i == 13:
                self.features.append(x)
            x = l(x)
        skips = self.features[::-1]

        # skips = self.features[1:]

        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        x = self.seg(x)
        return x



四次Skipconnect分别在:Maxpool前;另外三次在通道数变化前。
上采样combine时采用的是插值(nn.functionnal.interpolate)。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185682.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 对猴子摘香蕉问题给出产生式系统描述_猴子接香蕉的编程

    对猴子摘香蕉问题给出产生式系统描述_猴子接香蕉的编程一个房间里,天花板上挂有一串香蕉,有一只猴子可在房间里任意活动(到处走动,推移箱子,攀登箱子等)。设房间里还有一只可被猴子移动的箱子,且猴子登上箱子时才能摘到香蕉,问猴子在某一状态下(设猴子位置为A,香蕉位置在B,箱子位置为C),如何行动可摘取到香蕉2.1猴子摘香蕉问题PEAS性能环境执行器感知器猴子站在箱香蕉MoveSite子上摘到香箱子ClimbHold蕉房间(a,b,c)PushOnGraspHangJump2.2定义谓词Site(x,w):物体x的位置是wHold(z):z手中拿着香蕉On(z):z

    2022年9月26日
    3
  • 实验三 简单结构局域网组建与配置

    实验三 简单结构局域网组建与配置实验目的 了解一个局域网的基本组成 掌握一个局域网设备互通所需的基本配置 掌握报文的基本传输过程 实验任务 1 根据所认识的设备设计一个简单的局域网并在仿真环境中画出其逻辑拓扑 2 配置拓扑中的各设备连通所需的参数 3 在模拟模式下进行包传输路径跟踪测试 建议实验学时 2 学时 实验背景 nbsp 简单的局域网主要由交换机 HUB PC 等设备组建 他们的连接和配置比

    2025年10月6日
    2
  • 思科和华为交换机命令的区别(思科交换机和华为交换机的异同)

    思科和华为交换机常用命令对比一、调试命令思科:Switch#showrun显示所有配置命令Switch#showipinterbrief显示所有接口状态Switch#showvlanbrief显示所有VLAN的信息Switch#showversion显示版本信息华为:[Quidway]discur显示所有配置命令[Quidway…

    2022年4月16日
    76
  • 用Excel处理笛卡尔积

    用Excel处理笛卡尔积    工作中遇到需要处理笛卡尔积的需求,用数据库只需把需要做笛卡尔积的各列进行外链接就可以了,想到Excel应该可以处理这样的需求,就百度学习了一下,但还是看不太懂,下面只是依葫芦画瓢做了一遍,记录一下。1、构建两列数据,如下图:2、构建D列辅助列,E列为用index函数处理A列后的数据,每个值的重复次数为B列的数值行数。3、用index()函数处理B列:…

    2022年7月27日
    121
  • 图的五种最短路径算法

    图的五种最短路径算法本文总结了图的几种最短路径算法的实现:深度或广度优先搜索算法,费罗伊德算法,迪杰斯特拉算法,Bellman-Ford算法。1)深度或广度优先搜索算法(解决单源最短路径)从起点开始访问所有深度遍历路径或广度优先路径,则到达终点节点的路径有多条,取其中路径权值最短的一条则为最短路径。下面是核心代码:voiddfs(intcur,intdst){if(minpath&lt;dst)r…

    2022年6月4日
    47
  • biztalk什么意思_aide教程网

    biztalk什么意思_aide教程网BizTalk开发系列(二十六) 使用Web Service

    2022年4月21日
    61

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号