pytorch(8)– resnet101 迁移学习记录

pytorch(8)– resnet101 迁移学习记录一、前言本篇记录使用pytorch官方resnet101实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training+fine-tuning(预训练+调参)的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步(1)把预训练模型当做特征提取器:TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

一、前言

    本篇记录使用 pytorch 官方 resnet101 实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training + fine-tuning(预训练+调参) 的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步

 (1)把预训练模型当做特征提取器: TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或者把最后一层直接去掉换成自己的分类器, 剩下的全部网络结构当做一个特征提取器。
 (2)fine-tuning: 通常来说,直接把预训练模型来用效果不一定足够好,因此需要进行fine-tuning(微调)。fine-tuning需要冻结网络的前几层参数,只更新网络结构的后面几层和最后的全连接层,这样效果会更好。
 (3) Learning rate: 在迁移学习的微调过程中一般不建议使用过大的学习率,通常来说1e-5是比较合适的选择

二、代码

resnet101 官网定义

import torch
from torchvision.models.resnet import ResNet, Bottleneck
 
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        checkpoint = torch.load('resnet101-5d3b4d8f.pth', map_location='cpu')  # 加载模型文件,pt, pth 文件都可以
        model.load_state_dict( checkpoint )
    return model

然后使用resnet101,加载官方预训练模型,再修改最后全连接层,训练过程只对最后全连接层做训练

    #初始化net,训练和验证都需要net
    
    net = resnet101(pretrained=True)
    net.fc = torch.nn.Sequential(torch.nn.Linear(2048, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 257 ))  

    net = net.to(device)   
    
    #初始化optimizer,只有train时使用
    optimizer = optim.SGD( net.fc.parameters(), lr=1e-5, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20 , gamma=0.5) 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185265.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • [数据库] 一文搞懂case when所有使用场景「建议收藏」

    [数据库] 一文搞懂case when所有使用场景「建议收藏」前几天,为了给产品分析当前用户数据结构,写sql的时候使用到了casewhen,今天来总结一下casewhen的使用方法,以此为戒,感觉写的不好请拍砖,感觉写的还可以,给哥们点个赞,或者回复一下,让我意识到我不是一个人在战斗,好了废话不多说了,进入正题。关于casewhen的使用情况,我总结下来有三种,第一、等值转换,第二、范围转换,第三、列转行操作。等值转换咱们在设计数据库的…

    2025年9月17日
    5
  • winhex 数据恢复_win10文件恢复软件

    winhex 数据恢复_win10文件恢复软件数据恢复分类:硬恢复和软恢复。所谓硬恢复就是硬盘出现物理性损伤,比如有盘体坏道、电路板芯片烧毁、盘体异响,等故障,由此所导致的普通用户不容易取出里面数据,那么我们将它修好,同时又保留里面的数据或后来恢

    2022年8月5日
    5
  • 有意思的批处理

    有意思的批处理

    2021年7月26日
    60
  • filter pitcher是什么意思_EncodingFilter

    filter pitcher是什么意思_EncodingFilterorg.apache.struts2.dispatcher.FilterDispatcher是Struts2的主要的Filter,负责四个方面的功能:       (1)执行Actions       (2)清除ActionContext       (3)维护静态内容       (4)清除request生命周期内的XWork的interceptors   另注:该

    2022年8月16日
    5
  • Hadoop生态系统简介

    Hadoop生态系统简介Hadoop生态系统主要包括:Hive、HBase、Pig、Sqoop、Flume、ZooKeeper、Mahout、Spark、Storm、Shark、Phoenix、Tez、Ambari。Hive:用于Hadoop的一个数据仓库系统,它提供了类似于SQL的查询语言,通过使用该语言可以方便地进行数据汇总,特定查询以及分析存放在Hadoop兼容文件系统中的大数据。HBase:一种分布的、可

    2022年5月19日
    39
  • SQL Server 下载安装教程

    SQL Server 下载安装教程SQLServer2017下载安装教程第一步:打开浏览器,在浏览的搜索框中我们输入“SQLServer”。如图,会匹配出中文两条微软官方下载页面(一个页面内容是英文、一个页面内容是中文)。这里我们以中文的为例。第二步:点击进入下载页面后,可以看到如图所示页面,我们不要着急点击下载,因为这些SQLServer只能试用180天(大家从介绍中可以看到)。第三步:我们将网页下滑,可以看到…

    2022年4月29日
    54

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号