Sklearn中的CV与KFold详解

Sklearn中的CV与KFold详解关于交叉验证,我在之前的文章中已经进行了简单的介绍,而现在我们则通过几个更加详尽的例子.详细的介绍CV%matplotlibinlineimportnumpyasnpfromsklearn.model_selectionimporttrain_test_splitfromsklearnimportdatasetsfromsklearnimports…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

关于交叉验证,我在之前的文章中已经进行了简单的介绍,而现在我们则通过几个更加详尽的例子.详细的介绍

CV

%matplotlib inline
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm

iris = datasets.load_iris()
iris.data.shape,iris.target.shape
((150, 4), (150,))

一般的分割方式,训练集-测试集.然而这种方式并不是很好

X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.4, random_state=0) 

clf_svc = svm.SVC(kernel='linear').fit(X_train,y_train)
clf_svc.score(X_test,y_test)
0.9666666666666667
  • 缺点一:浪费数据
  • 缺点二:容易过拟合,且矫正方式不方便

这时,我们需要使用另外一种分割方式-交叉验证

from sklearn.model_selection import cross_val_score
clf_svc_cv = svm.SVC(kernel='linear',C=1)
scores_clf_svc_cv = cross_val_score(clf_svc_cv,iris.data,iris.target,cv=5)
print(scores_clf_svc_cv)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv.mean(), scores_clf_svc_cv.std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
Accuracy: 0.98 (+/- 0.03)

同时我们也可以为cross_val_score选择不同的性能度量函数

from sklearn import metrics
scores_clf_svc_cv_f1 = cross_val_score(clf_svc_cv,iris.data,iris.target,cv=5,scoring='f1_macro')
print("F1: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv_f1.mean(), scores_clf_svc_cv_f1.std() * 2))
F1: 0.98 (+/- 0.03)

同时也正是这些特性使得,cv与数据转化以及pipline(sklearn中的管道机制)变得更加契合

from sklearn import preprocessing
from sklearn.pipeline import make_pipeline
clf_pipline = make_pipeline(preprocessing.StandardScaler(),svm.SVC(C=1))
scores_pipline_cv = cross_val_score(clf_pipline,iris.data,iris.target,cv=5)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv_f1.mean(), scores_clf_svc_cv_f1.std() * 2))
Accuracy: 0.98 (+/- 0.03)

同时我们还可以在交叉验证使用多个度量函数

from sklearn.model_selection import cross_validate
from sklearn import metrics

scoring = ['precision_macro', 'recall_macro']
clf_cvs = svm.SVC(kernel='linear', C=1, random_state=0)
scores_cvs = cross_validate(clf_cvs,iris.data,iris.target,cv=5,scoring=scoring,return_train_score = False)
sorted(scores_cvs.keys())
['fit_time', 'score_time', 'test_precision_macro', 'test_recall_macro']
print(scores_cvs['test_recall_macro'])
print("test_recall_macro: %0.2f (+/- %0.2f)" % (scores_cvs['test_recall_macro'].mean(), scores_cvs['test_recall_macro'].std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
test_recall_macro: 0.98 (+/- 0.03)

同时cross_validate也可以使用make_scorer自定义度量功能
或者使用单一独量

from sklearn.metrics.scorer import make_scorer
scoring_new = { 
   'prec_macro': 'precision_macro','recall_micro': make_scorer(metrics.recall_score, average='macro')}
# 注意此处的make_scorer
scores_cvs_new = cross_validate(clf_cvs,iris.data,iris.target,cv=5,scoring=scoring_new,return_train_score = False)
sorted(scores_cvs_new.keys())
['fit_time', 'score_time', 'test_prec_macro', 'test_recall_micro']
print(scores_cvs_new['test_recall_micro'])
print("test_recall_micro: %0.2f (+/- %0.2f)" % (scores_cvs_new['test_recall_micro'].mean(), scores_cvs_new['test_recall_micro'].std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
test_recall_micro: 0.98 (+/- 0.03)

关于Sklearn中的CV还有cross_val_predict可用于预测,下面则是Sklearn中一个关于使用该方法进行可视化预测错误的案例

from sklearn import datasets
from sklearn.model_selection import cross_val_predict
from sklearn import linear_model
import matplotlib.pyplot as plt

lr = linear_model.LinearRegression()
boston = datasets.load_boston()
y = boston.target

# cross_val_predict returns an array of the same size as `y` where each entry
# is a prediction obtained by cross validation:
predicted = cross_val_predict(lr, boston.data, y, cv=10)

fig, ax = plt.subplots()
fig.set_size_inches(18.5,10.5)
ax.scatter(y, predicted, edgecolors=(0, 0, 0))
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=4)
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()

通过交叉验证进行数据错误的可视化

KFlod的例子

Stratified k-fold:实现了分层交叉切分

from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2, 3, 4],
              [11, 12, 13, 14],
              [21, 22, 23, 24],
              [31, 32, 33, 34],
              [41, 42, 43, 44],
              [51, 52, 53, 54],
              [61, 62, 63, 64],
              [71, 72, 73, 74]])

y = np.array([1, 1, 0, 0, 1, 1, 0, 0])

stratified_folder = StratifiedKFold(n_splits=4, random_state=0, shuffle=False)
for train_index, test_index in stratified_folder.split(X, y):
    print("Stratified Train Index:", train_index)
    print("Stratified Test Index:", test_index)
    print("Stratified y_train:", y[train_index])
    print("Stratified y_test:", y[test_index],'\n')
Stratified Train Index: [1 3 4 5 6 7]
Stratified Test Index: [0 2]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 2 4 5 6 7]
Stratified Test Index: [1 3]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 1 2 3 5 7]
Stratified Test Index: [4 6]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 1 2 3 4 6]
Stratified Test Index: [5 7]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] 
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2, 3, 4],
              [11, 12, 13, 14],
              [21, 22, 23, 24],
              [31, 32, 33, 34],
              [41, 42, 43, 44],
              [51, 52, 53, 54],
              [61, 62, 63, 64],
              [71, 72, 73, 74]])

y = np.array([1, 1, 0, 0, 1, 1, 0, 0])

stratified_folder = StratifiedKFold(n_splits=4, random_state=0, shuffle=False)
for train_index, test_index in stratified_folder.split(X, y):
    print("Stratified Train Index:", train_index)
    print("Stratified Test Index:", test_index)
    print("Stratified y_train:", y[train_index])
    print("Stratified y_test:", y[test_index],'\n')
Stratified Train Index: [1 3 4 5 6 7]
Stratified Test Index: [0 2]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 2 4 5 6 7]
Stratified Test Index: [1 3]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 1 2 3 5 7]
Stratified Test Index: [4 6]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] 

Stratified Train Index: [0 1 2 3 4 6]
Stratified Test Index: [5 7]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] 

除了这几种交叉切分KFlod外,还有很多其他的分割方式,比如StratifiedShuffleSplit重复分层KFold,实现了每个K中各类别的比例与原数据集大致一致,而RepeatedStratifiedKFold 可用于在每次重复中用不同的随机化重复分层 K-Fold n 次。至此基本的KFlod在Sklearn中都实现了

注意

i.i.d 数据是机器学习理论中的一个常见假设,在实践中很少成立。如果知道样本是使用时间相关的过程生成的,则使用 time-series aware cross-validation scheme 更安全。 同样,如果我们知道生成过程具有 group structure (群体结构)(从不同 subjects(主体) , experiments(实验), measurement devices (测量设备)收集的样本),则使用 group-wise cross-validation 更安全。

下面就是一个分组KFold的例子,

from sklearn.model_selection import GroupKFold

X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]

gkf = GroupKFold(n_splits=3)
for train, test in gkf.split(X, y, groups=groups):
    print("%s %s" % (train, test))
[0 1 2 3 4 5] [6 7 8 9]
[0 1 2 6 7 8 9] [3 4 5]
[3 4 5 6 7 8 9] [0 1 2]

更多内容请参考:sklearn相应手册

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/191284.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • js数组反转的几种方法「建议收藏」

    js数组反转的几种方法「建议收藏」第一种:创建一个新数组使用reverse()的方法进行反转。letarr=[1,2,3,4]letarr1=arr.reverse()console.log(arr1);//[4,3,2,1]第二种:利用数组循环,使用unshift()方法将新项添加到数组的开头,并返回新的长度。unshift()方法会改变数组的长度。letarr2=[1,2,3,4]letarr3=[]arr2.forEach((element)=>{..

    2022年6月1日
    134
  • 大文件或目录复制时的信息统计脚本

    大文件或目录复制时的信息统计脚本

    2021年8月26日
    47
  • MyBatis核心组件之SqlSessionFactory

    MyBatis核心组件之SqlSessionFactoryMyBatis的核心组件MyBatis的核心组件分为4个部分:SqlSessionFactoryBuilder(构造器):它会根据配置或者代码来生成SqlSessionFactory,采用的是分布构建的Builder模式。SqlSessionFactory(工厂接口):依靠它来生成SqlSession,使用的是工厂模式。SqlSession(会话):一个既可以发送SQL执行返回结果,也可…

    2022年5月22日
    26
  • java后端简历项目经历_简历上的项目经历怎么写 ?这 3 条原则不可忽视 !

    java后端简历项目经历_简历上的项目经历怎么写 ?这 3 条原则不可忽视 !阅读本文大概需要 5 分钟 作者 黄小斜作为一个程序员 想必大家曾经都做过一些项目 可能现在手头上也还有一些项目 不过还是有很多学生朋友来问我 没有项目怎么办 诚然 确实有不少同学没有实习经历 又没有什么像样的项目经历 对于这样的同学 简历上的项目经历难道只能空着了吗 其实不然 就算你是跟着一些课程做项目 你也可以通过丰富项目内容的方法把项目变成自己的 只要你真的去做了 真的理解了代码逻辑 同时

    2025年12月3日
    2
  • Windows 7 连接 Windows 10 共享打印机,Windows 无法连接打印机,操作失败,错误为0x0000011b 的终极解决办法

    Windows 7 连接 Windows 10 共享打印机,Windows 无法连接打印机,操作失败,错误为0x0000011b 的终极解决办法Windows7连接Windows10共享打印机出现错误0x000001b,无法通过卸载KB5005565安全更新来解决该问题,正确的处理方法是手工添加一个本地打印机,本方法稳定可靠。本文详述了该方法的操作步骤。

    2025年10月22日
    5
  • ios最新屏蔽更新描述文件(屏蔽ios14系统升级描述文件)

    崇尚专注乐于分享愿为您带来生活中的便利每晚8点期待您的到来捧场目前我们的公众号已经为上万人提供了帮助,新来的小伙伴,如果你不想错过每一期的资源,需要获取往期分享的资源,可以在公众号菜单栏找到“资源汇总”即可获取到全部资源。公众号内所有资源皆为免费分享,不收取任何费用.公众号内资源大部分来源于网络,不保证永久有效.资源都经过小编实机测试,但不保证兼容所有机型.公众号经过一系列改版现在推送…

    2022年4月16日
    60

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号