MATLAB对Googlenet模型进行迁移学习

MATLAB对Googlenet模型进行迁移学习调用MATLAB中的Googlenet工具箱进行迁移学习。%%加载数据clc;closeall;clear;Location=”;%这里输入自己的数据集unzip(‘MerchData.zip’);imds=imageDatastore(‘MerchData’,…%若使用自己的数据集则改为Location(不加单引号)…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

调用MATLAB中的Googlenet工具箱进行迁移学习

%% 加载数据
clc;close all;clear;
Location = '';%这里输入自己的数据集
unzip('MerchData.zip');
imds = imageDatastore('MerchData',...  %若使用自己的数据集则改为Location(不加单引号)
                       'IncludeSubfolders',true,...
                       'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');%将数据集按7:3的比例分为训练集和测试集
%% 加载预训练网络
net = googlenet;
%% 从训练有素的网络中提取图层,并绘制图层图
lgraph = layerGraph(net);%从训练网络中提取layer graph

%绘制layer graph
% figure('Units','normalize','Position',[0.1 0.1 0.8 0.8]);
% plot(lgraph)
% net.Layers(1)

inputSize = net.Layers(1).InputSize;


%% 替换最终图层
% 为了训练Googlenet去分类新的图像,取代网络的最后三层。这三层为'loss3-classifier', 'prob', 和
% 'output',包含如何将网络的提取的功能组合为类概率和标签的信息。在层次图中添加三层新层: a fully connected layer, a softmax layer, and a classification output layer
% 将全连接层设置为同新的数据集中类的数目相同的大小,为了使新层比传输层学习更快,增加全连接层的学习因子。
lgraph = removeLayers(lgraph,{'loss3-classifier','prob','output'});
numClasses = numel(categories(imdsTrain.Labels));
newLayers = [
              fullyConnectedLayer(numClasses,'Name','fc','weightLearnRateFactor',10,'BiasLearnRateFactor',10)
              softmaxLayer('Name','softmax')
              classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);

%将网络中最后一个传输层(pool5-drop_7x7_s1)连接到新层
lgraph = connectLayers(lgraph,'pool5-drop_7x7_s1','fc');

% 绘制新的图层
% figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
% plot(lgraph)
% ylim([0,10])
 
 %% 冻结初始图层
 % 这个网络现在已经准备好训练新的图像集。或者你可以通过设置这些层的学习速率为0来“冻结”网络中早期层的权重
 %在训练过程中trainNetwork不会跟新冻结层的参数,因为冻结层的梯度不需要计算,冻结大多数初始层的权重对网络训练加速很重要。
 %如果新的数据集很小,冻结早期网络层也可以防止新的数据集过拟合。

 layers = lgraph.Layers;
 connections = lgraph.Connections;  
 layers(1:110) = freezeWeights(layers(1:110));%调用freezeWeights函数,设置开始的110层学习速率为0
 lgraph = createLgraphUsingConnections(layers,connections);%调用createLgraphUsingConnections函数,按原始顺序重新连接所有的层。


%% 训练网络
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter(...
                                    'RandXReflection',true,...
                                    'RandXTranslation',pixelRange,...
                                    'RandYTranslation',pixelRange);
%对输入数据进行数据加强
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
 %  自动调整验证图像大小
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
 %设置训练参数
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...  %设置验证频率
    'ValidationPatience',Inf, ...
    'Verbose',true ,...
    'Plots','training-progress');
 %开始训练网络
googlenetTrain = trainNetwork(augimdsTrain,lgraph,options);
 
 
%% 对验证图像进行分类
[YPred,probs] = classify(googlenetTrain,augimdsValidation);%使用训练好的网络进行分类
accuracy = mean(YPred == imdsValidation.Labels)%计算网络的精确度

%% 保存训练好的模型
 save googlenet_03 googlenetTrain;
% save  x  y;  保存训练好的模型y(注意:y为训练的模型,即y = trainNetwork()),取名为x

使用训练好的模型进行图像分类
我这里训练的模型是对细胞显微图像进行分类,包括BYST,GRAN,HYAL,MUCS,RBC,WBC,WBCC七种细胞。

%% 加载模型
clc;close all;clear;
load('-mat','E:\MATLAB_Code\googlenet_1');
%% 加载测试集
Location = 'E:\image_test\test_02';
imds = imageDatastore(Location,'includeSubfolders',true,'LabelSource','foldernames');
inputSize = googlenetTrain.Layers(1).InputSize; 
imdstest = augmentedImageDatastore(inputSize(1:2),imds);
tic;
YPred = classify(googlenetTrain,imdstest);
%使用训练好的模型对测试集进行分类
disp(['分类所用时间为:',num2str(toc),'秒']);
%% 显示分类结果,绘制混淆矩阵
byst = 'BYST';
BYST = numel(YPred,YPred == byst);
disp(['BYST = ',num2str(BYST)]);
gran = 'GRAN';
GRAN = numel(YPred,YPred == gran);
disp(['GRAN = ',num2str(GRAN)]);
hyal = 'HYAL';
HYAL = numel(YPred,YPred == hyal);
disp(['HYAL = ',num2str(HYAL)]);
mucs = 'MUCS';
MUCS = numel(YPred,YPred == mucs);
disp(['MUCS = ',num2str(MUCS)]);
rbc = 'RBC';
RBC = numel(YPred,YPred == rbc);
disp(['RBC = ',num2str(RBC)]);
wbc = 'WBC';
WBC = numel(YPred,YPred == wbc);
disp(['WBC = ',num2str(WBC)]);
wbcc = 'WBCC';
WBCC = numel(YPred,YPred == wbcc);
disp(['WBCC = ',num2str(WBCC)]);
sum = numel(YPred);
disp(['sum = ',num2str(sum)]);
%求出每个标签对应的分类数量
% numel(A)  返回数组A的数目
% numel(A,x) 返回数组A在x的条件下的数目
%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%计算精确度%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 YTest = imds.Labels;
 accuracy = mean(YPred == YTest);
 disp(['accuracy = ',num2str(accuracy)]);
 % disp(x)   变量x的值
 % num2str(x)  将数值数值转换为表示数字的字符数组
 %%
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%随机显示测试分类后的图片%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
idx = randperm(numel(imds.Files),16);
figure
for i = 1:16
    subplot(4,4,i);
    I = readimage(imds,idx(i));
    imshow(I);
    label = YPred(idx(i));
    title(string(label));
end

%% 绘制混淆矩阵
predictLabel = YPred;%通过训练好的模型分类后的标签
actualLabel = YTest;%原始的标签
plotconfusion(actualLabel,predictLabel,'Googlenet');%绘制混淆矩阵


%    plotconfusion(targets,outputs);绘制混淆矩阵,使用target(true)和output(predict)标签,将标签指定为分类向量或1到N的形式
%% 保存分类后的图片
x = numel(imds.Files);
% 图片保存位置
Location_BYST = 'E:\image_classification\Googlenet\BYST';
Location_GRAN = 'E:\image_classification\Googlenet\GRAN';
Location_HYAL = 'E:\image_classification\Googlenet\HYAL';
Location_MUCS = 'E:\image_classification\Googlenet\MUCS';
Location_RBC  = 'E:\image_classification\Googlenet\RBC';
Location_WBC  = 'E:\image_classification\Googlenet\WBC';
Location_WBCC = 'E:\image_classification\Googlenet\WBCC';
writePostfix = '.bmp';%图片保存后缀
for i = 1:x
    I = readimage(imds,i);
    Label = YPred(i);
    Name = YTest(i);
   switch Label
       case 'BYST'
           saveName = sprintf('%s%s%s_%d',Location_BYST,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'GRAN'
           saveName = sprintf('%s%s%s_%d',Location_GRAN,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'HYAL'
           saveName = sprintf('%s%s%s_%d',Location_HYAL,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'MUCS'
           saveName = sprintf('%s%s%s_%d',Location_MUCS,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'RBC'
           saveName = sprintf('%s%s%s_%d',Location_RBC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'WBC'
           saveName = sprintf('%s%s%s_%d',Location_WBC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'WBCC'
           saveName = sprintf('%s%s%s_%d',Location_WBCC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
   end
    
    
end

结果:
在这里插入图片描述

在这里插入图片描述

附录
freezeWeights函数

function layers = freezeWeights(layers)

for ii = 1:size(layers,1)
    props = properties(layers(ii));
    for p = 1:numel(props)
        propName = props{p};
        if ~isempty(regexp(propName,'LearnRateFactor$','once'))
            layers(ii).(propName) = 0;
        end
    end
end
end

createLgraphUsingConnections函数

function lgraph = createLgraphUsingConnections(layers,connections)

lgraph = layerGraph();

for i = 1:numel(layers)
    lgraph = addLayers(lgraph,layers(i));
end
    

for c = 1:size(connections,1)
    lgraph = connectLayers(lgraph,connections.Source{c},connections.Destination{c});
end
end
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/170455.html原文链接:https://javaforall.net

(0)
上一篇 2022年8月14日 下午10:16
下一篇 2022年8月14日 下午10:16


相关推荐

  • 便携AI聚合API ChatGPT 4o生图功能使用教程

    便携AI聚合API ChatGPT 4o生图功能使用教程

    2026年3月16日
    3
  • Python—数字推盘游戏设计

    Python—数字推盘游戏设计目标 了解 pygame 模块的框架与基础函数 熟悉 MVC 设计模式 掌握自顶向下的程序设计方式 内容 完成数字推盘游戏设计步骤 代码如下 importpygame localsimport 定义常量 WINWIDTH 640 窗口宽度 WINHEIGHT 480 窗口高度 ROW 3COL 3BLANK None

    2026年3月26日
    2
  • JAVA反射机制

    JAVA反射机制

    2021年11月15日
    46
  • nmap命令教程详解

    nmap命令教程详解-sP:ping扫描(不进行端口扫描)-sT:进行TCP全连接扫描-sS:进行SYN半连接扫描-sF:进行FIN扫描-sN:进行Null扫描-sX:进行Xmas扫描-O:进行测探目标主机版本(不是很准)-sV:可以显示服务的详细版本-A:全面扫描-p:指定端口扫描-oN:会将扫描出来的结果保存成一个txt文件-oX:会将扫描出来的结果保存成一个xml文件[-T1]-[-T5]:提高扫描速度.详细分析1)、主机发现nmap-sP192.168.1

    2022年5月28日
    52
  • navicat premium12 mac 激活码【中文破解版】

    (navicat premium12 mac 激活码)好多小伙伴总是说激活码老是失效,太麻烦,关注/收藏全栈君太难教程,2021永久激活的方法等着你。IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.htmlS32PGH0SQB-eyJsaWNlbnNlSWQi…

    2022年3月26日
    765
  • java linkhashset_java中集合怎么定义

    java linkhashset_java中集合怎么定义LinkedHashSet是Set集合的一个实现,具有set集合不重复的特点,同时具有可预测的迭代顺序,也就是我们插入的顺序。并且linkedHashSet是一个非线程安全的集合。如果有多个线程同时访问当前linkedhashset集合容器,并且有一个线程对当前容器中的元素做了修改,那么必须要在外部实现同步保证数据的冥等性。下面我们new一个新的LinkedHashSet容器看一下具体的源码实现。…

    2022年10月12日
    6

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号