resnet源码pytorch_pytorch conv1d

resnet源码pytorch_pytorch conv1d#Pytorch 0.4.0 ResNet34实现cifar10分类.#@Time:2018/6/17#@Author:xfLiimporttorchvisionastvimporttorchastimporttorchvision.transformsastransformsfromtorchimportnnfromtorch.utils.da…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用


# Pytorch 0.4.0 ResNet34实现cifar10分类.
# @Time: 2018/6/17
# @Author: xfLi

import torchvision as tv
import torch as t
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
t.set_num_threads(8)


class ResidualBloak(nn.Module):
    #残差块
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBloak, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel))
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

class ResNet34(nn.Module):
    #  实现主module:ResNet34  
    #  ResNet34 包含多个layer,每个layer又包含多个residual block  
    #  用子module来实现residual block,用_make_layer函数来实现layer 
    def __init__(self, num_classes):
        super(ResNet34, self).__init__()
        #前几层图像转换
        self.pre = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))

        # 重复的layer,分别有3,4,6,3个residual block
        self.layer1 = self._make_layer(16, 16, 3, stride=1)
        self.layer2 = self._make_layer(16, 32, 4, stride=1)
        self.layer3 = self._make_layer(32, 64, 6, stride=1)
        self.layer4 = self._make_layer(64, 64, 3, stride=1)
        #分类用的全连接
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        #构建layer,包含多个residual block
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel))
        layer = []
        layer.append(ResidualBloak(inchannel, outchannel, stride, shortcut))
        for i in range(1, block_num):
            layer.append(ResidualBloak(outchannel, outchannel))
        return nn.Sequential(*layer)

    def forward(self, x):
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def getData(): # 定义对数据的预处理  
    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor()])
    #训练集
    trainset = tv.datasets.CIFAR10(root='/data/', train=True, transform=transform, download=True)
    trainset_loader = DataLoader(trainset, batch_size=4, shuffle=True)
    #测试集
    testset = tv.datasets.CIFAR10(root='/data/', train=False, transform=transform, download=True)
    testset_loader = DataLoader(testset, batch_size=4, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainset_loader, testset_loader, classes

def train(): #训练
    trainset_loader, testset_loader, _ = getData() #获取数据
    net = ResNet34(10)
    print(net)
    criterion = nn.CrossEntropyLoss()
    optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #优化器

    for epoch in range(1):
        for step, (inputs,labels) in enumerate(trainset_loader):
            optimizer.zero_grad() #梯度清零
            output = net(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            if step % 10 ==9:
                acc = test(net, testset_loader)
                print('Epoch', epoch, '|step ', step, 'loss: %.4f' %loss.item(), 'test accuracy:%.4f' %acc)
    print('Finished Training')
    return net

def test(net, testdata): #测试集
    correct, total = .0, .0
    for inputs, label in testdata:
        net.eval()
        output = net(inputs)
        _, predicted = t.max(output, 1) #分类结果
        total += label.size(0)
        correct += (predicted == label).sum()
    return float(correct) / total

if __name__ == '__main__':
    net = train()








版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185357.html原文链接:https://javaforall.net

(0)
上一篇 2022年10月6日 下午12:16
下一篇 2022年10月6日 下午12:16


相关推荐

  • Java 继承、多态与类的复用

    Java 继承、多态与类的复用本文结合Java的类的复用对面向对象两大特征继承和多态进行了全面的介绍。首先,我们介绍了继承的实质和意义,并探讨了继承,组合和代理在类的复用方面的异同。紧接着,我们根据继承引入了多态,介绍了它的实现机制和具体应用。此外,为了更好地理解继承和多态,我们对final关键字进行了全面的介绍。在此基础上,我们介绍了Java中类的加载及初始化顺序。最后,我们对面向对象设计中三个十分重要的概念-重载、覆盖与隐藏进行了详细的说明。

    2022年7月8日
    18
  • MySQL数据库:分区Partition

    MySQL数据库:分区Partition

    2021年4月9日
    168
  • Jupyter Notebook 的快捷键

    Jupyter Notebook 的快捷键JupyterNoteb 的快捷键 JupyterNoteb 有两种键盘输入模式 编辑模式 允许你往单元中键入代码或文本 这时的单元框线是绿色的 命令模式 键盘输入运行程序命令 这时的单元框线是灰色 命令模式 按键 Esc 开启 Enter 转入编辑模式 Shift Enter 运行本单元 选中下个单元 Ctrl Enter 运行本单元 Alt Enter

    2026年3月19日
    2
  • 斯坦福2025年HAI报告出炉 国产大模型仅讯飞星火入围Mix-Eval前十

    斯坦福2025年HAI报告出炉 国产大模型仅讯飞星火入围Mix-Eval前十

    2026年3月14日
    3
  • Eclipse导入Java项目报错

    Eclipse导入Java项目报错Eclipse 中导入 Java 项目出现 Noprojectsar 原因 这其实是因为项目中缺少了两个文件 classpath 文件和 project 文件 所以 eclipse 找不到你的项目了 解决方案 在 Eclipse 中再新建一个新的项目 项目的类型和名称和导入的项目名一样 然后再新建的项目目录下 找到 classpath 文件和 project 文件 把它们复制到想要导入的项目中 最后就可以成功导入 不会报错了

    2026年3月19日
    1
  • springboot启动成功访问404_springboot启动执行

    springboot启动成功访问404_springboot启动执行今天在做一个springboot项目的时候,是接着别人的项目写的,写完之后想做一下测试,于是就启动了springboot,然后在放问的时候,一直包404的错误,然后百度了一下网上给的方法,包括注解使用@RestController,然后去除掉方法@RequestMapping(value="/add")中的“value=”,这个方法对我无用,因为我的项目之前就是用的@RestC…

    2022年10月13日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号