pytorch(8)– resnet101 迁移学习记录

pytorch(8)– resnet101 迁移学习记录一、前言本篇记录使用pytorch官方resnet101实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training+fine-tuning(预训练+调参)的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步(1)把预训练模型当做特征提取器:TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

一、前言

    本篇记录使用 pytorch 官方 resnet101 实现迁移学习,迁移学习是当前深度学习领域的一系列通用的解决方案,而不是一个具体的算法模型。Pre-training + fine-tuning(预训练+调参) 的迁移学习方式是现在深度学习中一个非常流行的迁移学习方式,有以下3步

 (1)把预训练模型当做特征提取器: TensorFlow或者Pytorch都有ImageNet上预训练好的模型,将最后一层全连接层(原始的是1000个类别或者更多)改成你自己的分类任务的种类进行输出,或者把最后一层直接去掉换成自己的分类器, 剩下的全部网络结构当做一个特征提取器。
 (2)fine-tuning: 通常来说,直接把预训练模型来用效果不一定足够好,因此需要进行fine-tuning(微调)。fine-tuning需要冻结网络的前几层参数,只更新网络结构的后面几层和最后的全连接层,这样效果会更好。
 (3) Learning rate: 在迁移学习的微调过程中一般不建议使用过大的学习率,通常来说1e-5是比较合适的选择

二、代码

resnet101 官网定义

import torch
from torchvision.models.resnet import ResNet, Bottleneck
 
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        checkpoint = torch.load('resnet101-5d3b4d8f.pth', map_location='cpu')  # 加载模型文件,pt, pth 文件都可以
        model.load_state_dict( checkpoint )
    return model

然后使用resnet101,加载官方预训练模型,再修改最后全连接层,训练过程只对最后全连接层做训练

    #初始化net,训练和验证都需要net
    
    net = resnet101(pretrained=True)
    net.fc = torch.nn.Sequential(torch.nn.Linear(2048, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 1024),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(1024, 257 ))  

    net = net.to(device)   
    
    #初始化optimizer,只有train时使用
    optimizer = optim.SGD( net.fc.parameters(), lr=1e-5, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20 , gamma=0.5) 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/185265.html原文链接:https://javaforall.net

(0)
上一篇 2022年10月6日 下午4:16
下一篇 2022年10月6日 下午4:16


相关推荐

  • 算法交易:华尔街怪兽的核武器

    算法交易:华尔街怪兽的核武器1980年华尔街的黑客生涯:天时地利20世纪70年代末期,算法开始进入人们的工作,这一趋势席卷了世界各地的金融市场,标志着华尔街黑客时代已然来临。华尔街逐渐吸引了美国越来越多杰出的数学家和科学家投身于编写交易算法的工作。在布莱克?斯科尔斯统治市场之前,已经有少数工程师和科学家进入曼哈顿下城市场了,但他们大都是外来移民。麻省理工、哈佛和此类高等学府的工程楼和科学楼成了招聘者竞相争夺人才…

    2022年7月11日
    17
  • linux dstat,dstat 用法详解

    linux dstat,dstat 用法详解Windows 下有性能监视器 Linux 下当然也不示弱 亲还在用 vmstat iostat nfsstat netstat ifstat 来查看系统性能状态 那你就弱爆了 今天给亲一个神器 只需他一个你就可以得到以上这么多工具综合的功能 闲言表过 步入正题 dstat 如果系统没有些工具 yum yinstalldsta 安装下即妥 此软件小巧玲珑 软件包大小只有 144k 安装

    2025年9月13日
    7
  • 【C】使用backtrace获取堆栈信息

    【C】使用backtrace获取堆栈信息1 backtrace 一些内存检测工具如 Valgrind 调试工具如 GDB 可以查看程序运行时函数调用的堆栈信息 有时候在分析程序时要获得堆栈信息 借助于 backtrace 是很有帮助的 其原型如下 includeexeci hintbacktrac voidbuffer intsize charbacktrac sy

    2026年1月30日
    2
  • 智慧工地 安全帽识别系统

    智慧工地 安全帽识别系统随着时代的发展科技也越来越发达,近些年来建筑行业在我国怦然兴起,关于建筑方面的隐患也日益增加,为此北京富维图像公司研发了一款智慧工地安全帽识别这项技术,可以有效的避免一些事故的发生。智慧工地安全帽识别这项技术它能有效识别作业过程中突发的危险以及预防危险的到临,以下是关于安全帽识别系统的详细介绍。第一北京富维图像公司对智慧工地安全帽识别系统采用人脸识别功能,可360°识别工作人员是否佩戴安全帽,可以通过前段摄像机进…

    2022年5月12日
    59
  • SSL证书安装指引

    SSL证书安装指引

    2021年10月14日
    50
  • IDataParameter调用存储过程

    IDataParameter调用存储过程

    2021年12月6日
    56

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号