pytorch(8)– resnet101 迁移学习记录

pytorch(8)– resnet101 迁移学习记录一、前言本篇记录使用pytorch官方resnet101实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training+fine-tuning(预训练+调参)的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步(1)把预训练模型当做特征提取器:TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

一、前言

    本篇记录使用 pytorch 官方 resnet101 实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training + fine-tuning(预训练+调参) 的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步

 (1)把预训练模型当做特征提取器: TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或者把最后一层直接去掉换成自己的分类器, 剩下的全部网络结构当做一个特征提取器。
 (2)fine-tuning: 通常来说,直接把预训练模型来用效果不一定足够好,因此需要进行fine-tuning(微调)。fine-tuning需要冻结网络的前几层参数,只更新网络结构的后面几层和最后的全连接层,这样效果会更好。
 (3) Learning rate: 在迁移学习的微调过程中一般不建议使用过大的学习率,通常来说1e-5是比较合适的选择

二、代码

resnet101 官网定义

import torch
from torchvision.models.resnet import ResNet, Bottleneck
 
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        checkpoint = torch.load('resnet101-5d3b4d8f.pth', map_location='cpu')  # 加载模型文件,pt, pth 文件都可以
        model.load_state_dict( checkpoint )
    return model

然后使用resnet101,加载官方预训练模型,再修改最后全连接层,训练过程只对最后全连接层做训练

    #初始化net,训练和验证都需要net
    
    net = resnet101(pretrained=True)
    net.fc = torch.nn.Sequential(torch.nn.Linear(2048, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 257 ))  

    net = net.to(device)   
    
    #初始化optimizer,只有train时使用
    optimizer = optim.SGD( net.fc.parameters(), lr=1e-5, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20 , gamma=0.5) 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185265.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • java反转数组_Java实现数组反转翻转的方法实例

    java反转数组_Java实现数组反转翻转的方法实例数组翻转的方法(java实现),数组翻转,就是将数组倒置,例如原数组为:{“a”,”b”,”c”,”d”},那么翻转后的数组为{“d”,”c”,”b”,”a”}。【方法一】使用集合个工具类:Collections.reverse(ArrayList)将数组进行反转:importjava.util.ArrayList;importjava.util.Collections;publiccl…

    2022年6月6日
    39
  • PHP表单数据写入MySQL代码

    参考:http://www.cnblogs.com/roucheng/p/phpmysql.html上面的代码用不同格式,不知道哪种格式更好

    2021年12月25日
    43
  • leetcode链表问题_c++反转链表

    leetcode链表问题_c++反转链表给你单链表的头指针 head 和两个整数 left 和 right ,其中 left <= right 。请你反转从位置 left 到位置 right 的链表节点,返回 反转后的链表 。示例 1:输入:head = [1,2,3,4,5], left = 2, right = 4输出:[1,4,3,2,5]示例 2:输入:head = [5], left = 1, right = 1输出:[5] 提示:链表中节点数目为 n1 <= n <= 500-500

    2022年8月8日
    4
  • 初识lunix_centos ubuntu

    初识lunix_centos ubuntuLinux常用快捷键    先安装rz指令,再使用rz进行导入文件    ls显示当前目录下的文件  ls-thal显示当前目录下的文件及详细信息  cd切换目录  mkdir新建目录  cp-r旧目录/新目录拷贝文件  rm-r目录删除文件  su账号名使用指定用户登录系统  tar压缩/解压命令    …

    2022年9月28日
    5
  • 77. Combinations「建议收藏」

    77. Combinations「建议收藏」Giventwointegers n and k,returnallpossiblecombinationsof k numbersoutof1… n.Forexample,If n =4and k =2,asolutionis:[[2,4],[3,4],[2,3],[1,2],[1,3],[1,4],]

    2022年10月6日
    2
  • html json数组拼接

    html json数组拼接作为一个菜鸟,自己想的笨办法2333,不过总归能用。。。//先定义一个json对象jsonstr=”[]”;jsonarray=eval(’(’+jsonstr+’)’);//传入两两个参数为格式相同数据不一样的json对象functionappenjson(jsonbject1,jsonbject2){//循环第一个传入的jsonfor(vari=0;i<…

    2022年5月27日
    64

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号