kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。于是我们会给出一系列的最大深度的值,比如{‘max_depth’:[1,2,3,4,5]},我们会尽可能包含最优最大深度。不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data[‘data’], data[‘target’], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {
‘max_depth’: range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print(‘best score: %f%grid.best_score_)

print(‘best parameters:’)

for key in parameters.keys():

print(%s: %d%(key, reg.get_params()[key]))

print(‘test score: %f%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/191614.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 深入理解MySQL索引原理和实现——为什么索引可以加速查询?

    深入理解MySQL索引原理和实现——为什么索引可以加速查询?说到索引,很多人都知道“索引是一个排序的列表,在这个列表中存储着索引的值和包含这个值的数据所在行的物理地址,在数据十分庞大的时候,索引可以大大加快查询的速度,这是因为使用索引后可以不用扫描全表来定位某行的数据,而是先通过索引表找到该行数据对应的物理地址然后访问相应的数据。”但是索引是怎么实现的呢?因为索引并不是关系模型的组成部分,因此不同的DBMS有不同的实现,我们针对MySQL数据库的实现进…

    2022年6月24日
    39
  • Windows 家族的十二种常用密码破解法

    Windows 家族的十二种常用密码破解法

    2021年8月7日
    77
  • GIT在Linux上的安装和使用简介

    GIT在Linux上的安装和使用简介

    2021年9月8日
    69
  • C++ 中的容器类详解

    C++ 中的容器类详解C++中的容器类包括“顺序存储结构”和“关联存储结构”,前者包括vector,list,deque等;后者包括set,map,multiset,multimap等。若需要存储的元素数在编译器间就可以确定,可以使用数组来存储,否则,就需要用到容器类了。1、vector    连续存储结构,每个元素在内存上是连续的;    支持高效的随机访问和在尾端插入/删除操作,但其他位置的插入/删除操…

    2022年9月13日
    0
  • ubuntu ll命令[通俗易懂]

    ubuntu ll命令[通俗易懂]用过Redhat的朋友应该很熟悉ll这个命令,就相当于ls-l,但在Ubuntu中就不行了。严格来说ll不是一个命令,只是命令的别名而已。很多Linux用户都使用bashshell,对普通用户来说用得最多的就是命令补全(按tab键)和alias(别名)功能。Ubuntu默认建立的用户都用的bashshell,所以它也支持别名功能,我们只需要gedi

    2022年9月24日
    0
  • springboot实现查询手机号归属地

    springboot实现查询手机号归属地我也不知道咋写的,测试过了,反正能用就行;packagecom.example.needs.util;importorg.apache.http.HttpEntity;importorg.apache.http.ParseException;importorg.apache.http.client.methods.CloseableHttpResponse;importorg.apache.http.client.methods.HttpGet;importorg.apache.h

    2022年7月22日
    6

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号