kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。于是我们会给出一系列的最大深度的值,比如{‘max_depth’:[1,2,3,4,5]},我们会尽可能包含最优最大深度。不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data[‘data’], data[‘target’], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {
‘max_depth’: range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print(‘best score: %f%grid.best_score_)

print(‘best parameters:’)

for key in parameters.keys():

print(%s: %d%(key, reg.get_params()[key]))

print(‘test score: %f%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/191614.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 加载本地cifar10 数据集

    加载本地cifar10 数据集defload_CIFAR10(ROOT):”””loadallofcifar”””xs=[]ys=[]forbinrange(1,6):f=os.path.join(ROOT,’data_batch_%d’%(b,))X,Y=load_CIFAR_batch(f)xs.append(X)y…

    2022年6月22日
    31
  • 机器学习中学习曲线的 bias vs variance 以及 数据量m

    机器学习中学习曲线的 bias vs variance 以及 数据量m

    2021年11月21日
    43
  • H2数据库教程_h2数据库编辑数据库

    H2数据库教程_h2数据库编辑数据库 启动和使用H2控制台H2控制台应用程序允许您使用浏览器访问数据库。这可以是H2数据库,也可以是支持JDBCAPI的其他数据库。这是一个客户端/服务器应用程序,因此需要服务器和客户端(浏览器)来运行它。根据您的平台和环境,有多种方法可以启动H2控制台:OS 开始 视窗 单击[开始],[所有程序],[H2]和[H2控制台(命令行)]  系统托盘中将添加…

    2022年10月12日
    7
  • php 中更简洁的三元运算符 ?:

    php 中更简洁的三元运算符 ?:

    2021年10月24日
    50
  • 【每天一个 Linux 命令】tree命令

    【每天一个 Linux 命令】tree命令1.前言本文主要讲解Linux系统上的tree命令的详细使用方法。tree命令是一个小型的跨平台命令行程序,用于递归地以树状格式列出或显示目录的内容。它输出每个子目录中的目录路径和文件,以及子目录和文件总数的摘要。tree程序可以在Unix和类Unix系统(如Linux)中使用,也可以在DOS、Windows和许多其他操作系统中使用。它为输出操作提供了各种选项,从文件选项、排序选项到图形选项,并支持XML、JSON和HTML格式的输出。在这篇教程中,我们将通过使用案例演示如何使用tree命令递归

    2022年7月24日
    11
  • 智能小车设计规划_智能循迹避障小车设计

    智能小车设计规划_智能循迹避障小车设计摘要该课题主要基于单片机的循迹、避障、WiFi、蓝牙等功能的智能小车,在一些特殊环境下有着特殊的意义。硬件控制以arduino为控制核心。采用超声波避障和红外避障传感器共同完成寻迹、避障功能,并将相关信号传送给单片机,经单片机控制系统分析判断后控制驱动芯片驱动直流电机实现小车前进、后退、左转、右转,停止。软件采用移植性较好的c语言编写,通过手机蓝牙App实现对智能小车的控制。通过TCP/UD协…

    2022年10月18日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号