kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。于是我们会给出一系列的最大深度的值,比如{‘max_depth’:[1,2,3,4,5]},我们会尽可能包含最优最大深度。不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data[‘data’], data[‘target’], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {
‘max_depth’: range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print(‘best score: %f%grid.best_score_)

print(‘best parameters:’)

for key in parameters.keys():

print(%s: %d%(key, reg.get_params()[key]))

print(‘test score: %f%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/191614.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 中国天气网-天气预报接口api

    中国天气网-天气预报接口api

    2021年9月25日
    51
  • 如何修改bt tracker服务器,bt tracker服务器

    如何修改bt tracker服务器,bt tracker服务器bttracker服务器内容精选换一换云审计服务支持删除已创建的追踪器。删除追踪器对已有的操作记录没有影响,当您重新开通云审计服务后,依旧可以查看已有的操作记录。DELETE/v1.0/{project_id}/tracker无无无请参见错误码。本文操作介绍使用Linux操作系统的HECS(云耀云服务器)安装宝塔面板。宝塔面板是一款使用方便、功能强大且终身免费的服务器管理软件,支持Linux…

    2022年6月29日
    29
  • ListView 排序问题[通俗易懂]

    ListView 排序问题[通俗易懂] 在DataGrid 中有自带的排序,但是在ListeView中却没有这一项,下面就给出我平时用的ListView的排序使用方法,给志同道合的朋友们参考参考:在ColumnClick事件下添加:其中要注意:intsortColumn=-1;其中sortColumn用来记录上次排序的列的索引privatevoidlistView1_ColumnClick(objectse

    2022年10月3日
    0
  • 死锁与递归锁及信号量等[通俗易懂]

    死锁与递归锁进程也是有死锁的所谓死锁:是指两个或两个以上的进程或线程在执行过程中,因争夺资源而造成的一种互相等待的现象,若无外力作用,它们都将无法推进下去。此时称系统处于死锁状态或系统产生了死

    2022年3月29日
    34
  • linux基本命令_Git常用命令

    linux基本命令_Git常用命令用一个记一个。修改文件权限:chmod777文件名文件夹及其子文件:chmod-R777文件夹

    2022年10月3日
    0
  • linux中的ldd命令简介

    linux中的ldd命令简介在linux中,有些命令是大家通用的,比如ls,rm,mv,cp等等,这些我觉得没有必要再细说了。而有些命令,只有开发人员才会用到的,这类命令,作为程序员的我们,是有必要了解的,有的甚至需要熟练使用。有的人总说,这些命令不重要,用的时候去查就行了,这么多么扯淡的说法啊。具体用法细节是可以可查,但至少得知道有ldd这个东西吧。连ldd都不知道,怎么知道ldd是干啥的呢?

    2022年4月28日
    74

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号