pytorch(8)– resnet101 迁移学习记录

pytorch(8)– resnet101 迁移学习记录一、前言本篇记录使用pytorch官方resnet101实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training+fine-tuning(预训练+调参)的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步(1)把预训练模型当做特征提取器:TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

一、前言

    本篇记录使用 pytorch 官方 resnet101 实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training + fine-tuning(预训练+调参) 的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步

 (1)把预训练模型当做特征提取器: TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或者把最后一层直接去掉换成自己的分类器, 剩下的全部网络结构当做一个特征提取器。
 (2)fine-tuning: 通常来说,直接把预训练模型来用效果不一定足够好,因此需要进行fine-tuning(微调)。fine-tuning需要冻结网络的前几层参数,只更新网络结构的后面几层和最后的全连接层,这样效果会更好。
 (3) Learning rate: 在迁移学习的微调过程中一般不建议使用过大的学习率,通常来说1e-5是比较合适的选择

二、代码

resnet101 官网定义

import torch
from torchvision.models.resnet import ResNet, Bottleneck
 
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        checkpoint = torch.load('resnet101-5d3b4d8f.pth', map_location='cpu')  # 加载模型文件,pt, pth 文件都可以
        model.load_state_dict( checkpoint )
    return model

然后使用resnet101,加载官方预训练模型,再修改最后全连接层,训练过程只对最后全连接层做训练

    #初始化net,训练和验证都需要net
    
    net = resnet101(pretrained=True)
    net.fc = torch.nn.Sequential(torch.nn.Linear(2048, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 257 ))  

    net = net.to(device)   
    
    #初始化optimizer,只有train时使用
    optimizer = optim.SGD( net.fc.parameters(), lr=1e-5, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20 , gamma=0.5) 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185265.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 服务器风扇端子型号,出几样物品-相机连接头,服务器风扇,滤波器,接线端子等等如图…

    服务器风扇端子型号,出几样物品-相机连接头,服务器风扇,滤波器,接线端子等等如图…1名称:CCD12针相机连接头,二手拆机件成色如图!线长3厘米,重量约21克。8元/只http://img02.taobaocdn.com/imgextra/i2/119912523/T2642jXsNXXXXXXXXX_!!119912523.jpg2名称:服务器风扇,品牌:三洋SANYO,型号:SANACE409CRA0412P5G05,12V1.0A,长40mm*宽40mm*高56mm,…

    2022年6月22日
    45
  • Pycharm专业版注册激活

    Pycharm专业版注册激活快去这个链接:http://blog.csdn.net/lanchunhui/article/details/51660951http://idea.lanyus.com/

    2022年8月26日
    3
  • zigbee协议栈工作流程 From zigbee菜鸟笔记(十 一)

    zigbee协议栈工作流程 From zigbee菜鸟笔记(十 一)一.ZigBee协议栈简介什么是ZigBee协议栈呢?它和ZigBee协议有什么关系呢?协议是一系列的通信标准,通信双方需要共同按照这一标准进行正常的数据发射和接收。协议栈是协议的具体实现形式,通俗点来理解就是协议栈是协议和用户之间的一个接口,开发人员通过使用协议栈来使用个协议的,进而实现无线数据收发。ZigBee的协议分为两部分,IEEE802.15.4定义了PHY(物理层)和MAC(介质访问层)技术规范;ZigBee联盟定义了NWK(网络层)、APS(应用程序支持子层)、APL(应用层

    2022年5月28日
    43
  • clgao资源网址

    clgao资源网址工作流项目github:https://github.com/snakerflow/snaker-web

    2022年6月24日
    25
  • 超人学院Hadoop大数据资源共享

    超人学院Hadoop大数据资源共享

    2022年1月15日
    46
  • getopt在Python中的使用

    getopt在Python中的使用在运行程序时,可能需要根据不同的条件,输入不同的命令行选项来实现不同的功能。目前有短选项和长选项两种格式。短选项格式为”-“加上单个字母选项;长选项为”–“加上一个单词。长格式是在Linux下引入的。许多Linux程序都支持这两种格式。在Python中提供了getopt模块很好的实现了对这两种用法的支持,而且使用简单。取得命令行参数  在使用之前,首先要取得命令行参数。使用sys模块

    2022年4月30日
    41

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号