kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。于是我们会给出一系列的最大深度的值,比如{‘max_depth’:[1,2,3,4,5]},我们会尽可能包含最优最大深度。不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data[‘data’], data[‘target’], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {
‘max_depth’: range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print(‘best score: %f%grid.best_score_)

print(‘best parameters:’)

for key in parameters.keys():

print(%s: %d%(key, reg.get_params()[key]))

print(‘test score: %f%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/191614.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • idea2021.9激活码步骤【2021.8最新】

    (idea2021.9激活码步骤)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

    2022年3月27日
    46
  • Django(50)drf异常模块源码分析

    Django(50)drf异常模块源码分析异常模块源码入口APIView类中dispatch方法中的:response=self.handle_exception(exc)源码分析我们点击handle_exception跳转,查看该

    2022年8月7日
    6
  • 查看服务器的外网地址[通俗易懂]

    查看服务器的外网地址[通俗易懂]服务器上执行以下命令:curlmembers.3322.org/dyndns/getip转载于:https://blog.51cto.com/zhenfen/2106824

    2022年5月1日
    145
  • java集合类面试题_Java集合类相关面试题

    java集合类面试题_Java集合类相关面试题1、Collection和Collections的差别java.util.Collection是一个集合接口,Collection接口在Java类库中有非常多详细的实现。比如List、Setjava.util.Collections是针对集合类的一个帮助类,它提供了一系列的静态方法实现对各种集合的搜索、排序、线程安全化等操作。2、ArrayList与Vector的差别这两个类都实现了List接…

    2022年7月7日
    21
  • 服务器的文件不能修改器,荒野行动gg修改器脚本安装文件运行出错「建议收藏」

    服务器的文件不能修改器,荒野行动gg修改器脚本安装文件运行出错「建议收藏」PrivateDeclareFunctionCreateDirectoryLib”kernel32″Alias”CreateDirectoryA”(ByVallpPathNameAsString,lpSecurityAttributesAsSECURITY_ATTRIBUTES)AsLongPrivateTypeSECURITY_ATTRIBUTESnLengthAsLonglpSecurityD…

    2025年9月15日
    7
  • 电赛练习之旋转倒立摆

    电赛练习之旋转倒立摆2019年电赛已经结束,虽然结果不能令人满意,但闲下来,还是总结一下电赛学到的东西与失败的地方。这一次先来谈一下一阶旋转倒立摆。一、题目分析:拿到一道题目,其实最应该做的事情是分析题目,因为我们往往可以发现某些发挥题是在基础题的基础上进行的,但是,可能某些发挥题需要在基础题的基础上修改结构,我们也可以发现,题目中的某些问题具有相似性,当我们合并同类项的时候,可以把题目的要求变得简单。一下,我粘…

    2022年8月18日
    5

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号