Keras入门(八)K折交叉验证

Keras入门(八)K折交叉验证在文章 Keras 入门 一 搭建深度神经网络 DNN 解决多分类问题中 笔者介绍了如何搭建 DNN 模型来解决 IRIS 数据集的多分类问题 本文将在此基础上介绍如何在 Keras 中实现 K 折交叉验证 什么是 K 折交叉验证 K 折交叉验证是机器学习中的一个专业术语 它指的是将原始数据随机分成 K 份 每次选择 K 1 份作为训练集 剩余的 1 份作为测试集 交叉验证重复 K 次 取 K 次准确率的平均值作为最终模型的评价指标 一般取 K 10 即 10 折交叉验证 如下图所示 用交叉验证的目的是为了得到可靠稳定的模型 K 折交

什么是K折交叉验证?

Keras实现K折交叉验证

# -*- coding: utf-8 -*- # model_train.py # Python 3.6.8, TensorFlow 2.3.0, Keras 2.4.3 # 导入模块 import keras as K import pandas as pd from sklearn.model_selection import KFold # 读取CSV数据集 # 该函数的传入参数为csv_file_path: csv文件路径 def load_data(sv_file_path): iris = pd.read_csv(sv_file_path) target_var = 'class' # 目标变量 # 数据集的特征 features = list(iris.columns) features.remove(target_var) # 目标变量的类别 Class = iris[target_var].unique() # 目标变量的类别字典 Class_dict = dict(zip(Class, range(len(Class)))) # 增加一列target, 将目标变量转化为类别变量 iris['target'] = iris[target_var].apply(lambda x: Class_dict[x]) return features, 'target', iris # 创建模型 def create_model(): init = K.initializers.glorot_uniform(seed=1) simple_adam = K.optimizers.Adam() model = K.models.Sequential() model.add(K.layers.Dense(units=5, input_dim=4, kernel_initializer=init, activation='relu')) model.add(K.layers.Dense(units=6, kernel_initializer=init, activation='relu')) model.add(K.layers.Dense(units=3, kernel_initializer=init, activation='softmax')) model.compile(loss='sparse_categorical_crossentropy', optimizer=simple_adam, metrics=['accuracy']) return model def main(): # 1. 读取CSV数据集 print("Loading Iris data into memory") n_split = 10 features, target, data = load_data("./iris_data.csv") x = data[features] y = data[target] avg_accuracy = 0 avg_loss = 0 for train_index, test_index in KFold(n_split).split(x): print("test index: ", test_index) x_train, x_test = x.iloc[train_index], x.iloc[test_index] y_train, y_test = y.iloc[train_index], y.iloc[test_index] print("create model and train model") model = create_model() model.fit(x_train, y_train, batch_size=1, epochs=80, verbose=0) print('Model evaluation: ', model.evaluate(x_test, y_test)) avg_accuracy += model.evaluate(x_test, y_test)[1] avg_loss += model.evaluate(x_test, y_test)[0] print("K fold average accuracy: {}".format(avg_accuracy / n_split)) print("K fold average accuracy: {}".format(avg_loss / n_split)) main() 

模型的输出结果如下:

Iteration loss accuracy
1 0.00056 1.0
2 0.00021 1.0
3 0.00022 1.0
4 0.00608 1.0
5 0.21925 0.8667
6 0.52390 0.8667
7 0.00998 1.0
8 0.04431 1.0
9 0.14590 1.0
10 0.21286 0.8667
avg 0.11633 0.9600

10折交叉验证的平均loss为0.11633,平均准确率为96.00%。

总结

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/233872.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • volatile关键字作用

    volatile关键字作用一、作用简述内存可见性:保证变量的可见性:当一个被volatile关键字修饰的变量被一个线程修改的时候,其他线程可以立刻得到修改之后的结果。当一个线程向被volatile关键字修饰的变量写入数据的时候,虚拟机会强制它被值刷新到主内存中。当一个线程用到被volatile关键字修饰的值的时候,虚拟机会强制要求它从主内存中读取。 屏蔽JVM指令重排序(防止JVM编译源码生成class时使用重排序)…

    2022年6月1日
    48
  • 免费淘宝IP地址库简介及PHP/C#调用实例

    免费淘宝IP地址库简介及PHP/C#调用实例

    2021年10月17日
    49
  • 电脑显示器尺寸对照表_显示器选购攻略

    显示器是属于电脑的I/O设备,即输入输出设备。它可以分为CRT、LCD等多种。它是一种将一定的电子文件通过特定的传输设备显示到屏幕上再反射到人眼的显示工具。当用电脑来放松娱乐时,一个好的显示器则是必不可少的,看VCD时画面稳定;玩游戏时现场逼真,有一种身临其境的感觉,那种感觉一定特棒,这一切都取决于你选择的显示器品质的高低,对显示器的知识有一个综合的了解无疑会对你有所帮助,下面将就这一问…

    2022年4月4日
    558
  • MP4格式详解_mp4格式有哪些

    MP4格式详解_mp4格式有哪些一、mp4概述MP4文件中的所有数据都装在box(QuickTime中为atom)中,也就是说MP4文件由若干个box组成,每个box有类型和长度,可以将box理解为一个数据对象块。box中可以包含另一个box,这种box称为containerbox。一个MP4文件首先会有且只有一个“ftyp”类型的box,作为MP4格式的标志并包含关于文件的一些信息;之后会有且只有一个“moov”类型的box(MovieBox),它是一种containerbox,子box包含了媒体的metadata信息;MP4文

    2022年10月16日
    2
  • C语言课程设计图书管理系统_大一c语言课程设计模板

    C语言课程设计图书管理系统_大一c语言课程设计模板倾心原创,转载请备注原文地址,谢谢。主要内容:图书信息包括:书名、作者名、ISBN号、出版单位、出版年份、价格等。试设计一个图书信息管理系统,使之能提供以下功能:(1)系统以菜单方式工作(2)图书信息录入功能(图书信息用文件保存)(3)图书信息浏览功能(4)查询和排序功能:(至少一种查询方式)(5)修改图书信息:对某图书信息进行修改(6)删除图书:将某图书的信息删除…

    2022年10月11日
    5
  • generic host process已停止工作_windows error reporting 1001

    generic host process已停止工作_windows error reporting 1001故障现象:今天在虚拟机里装了win2003系统,每次重启进入系统时都会报错:generichostprocessforwin32services遇到了一个问题需要关闭。解决方法:先从google查了下相关问题,觉得没一个说来符合我的实际情况。于是回头仔细查看日志,怀疑是安装文件太旧引起的。于是更新补丁,当安装完了提示的99个补丁后,再重启进入系统,…

    2022年10月11日
    11

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号