在pytorch中实现十折交叉验证

在pytorch中实现十折交叉验证本想在网上找个代码 看到大部分写的代码有点乱 有些直接自己把交叉验证代码撸了出来 也不知道对不对 我不敢用 然后我还是自己结合 sklearn 库的交叉验证接口来应用到 torch 中进行交叉验证 关于各种交叉验证方式介绍可看这里 Sklearn 中不同的数据抽样验证方式 下面以 10 折交叉验证为例 结合 sklearn 库 给出一个在 pytorch 中进行 10 折交叉验证的模板 deftrain passdefeval passimportco

本想在网上找个代码,看到大部分写的代码有点乱,有些直接自己把交叉验证代码撸了出来?,也不知道对不对,我不敢用。然后我还是自己结合sklearn库的交叉验证接口来应用到torch中进行交叉验证。
关于各种交叉验证方式介绍可看这里:Sklearn中不同的数据抽样验证方式

下面以10折交叉验证为例,结合sklearn库。给出一个在pytorch中进行10折交叉验证的模板:

def train_(): pass def eval_(): pass import collections history = collections.defaultdict(list) # 记录每一折的各种指标 from sklearn.model_selection import KFold, StratifiedShuffleSplit, StratifiedKFold skf = KFold(n_splits=10,shuffle=True,random_state=42) #skf = StratifiedShuffleSplit(n_splits=10,test_size=0.1,random_state=42) #skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) for fold, (train_idx, val_idx) in enumerate(skf.split(x,y])): print(''*10,'第', fold+1, '折','ing....', ''*10) # for循环得到每一折的训练索引和验证索引,就可以对数据集进行抽取了。 # 抽取完之后,我们得到了训练数据和验证数据,那就分别转成torch的Dataset形式 # 然后再分别加载进torch的Dataloader里即可。 # 假设原数据集df存在Dataframe里,这里取出训练集和验证集 df_train = df.iloc[train_idx] df_train.index = range(len(df_train)) # 重置索引 df_val = df.iloc[val_idx] df_val.index = range(len(df_val)) # 重置索引 # 假设我们已经得到了训练数据的loader和验证数据的loader train_data_loader = ~ val_data_loader = ~ # 记住:每一折都要实例化新的模型,不然模型会学到测试集的东西 model = Model() for epoch in range(Epoch): print('——'*10, f'Epoch { 
     epoch + 1}/{ 
     EPOCHS}', '——'*10) train_(model, train_data_loader,...) metrics1, metrics2, ... = eval_(model, val_data_loader,...) history['metrics1'].append(metrics1) history['metrics2'].append(metrics2) . . . 剩下就是按你自己的逻辑写即可。 # 最后对每一折的结果取平均即可作为10折交叉验证的结果。 m1 = np.mean(history['metrics1']) m2 = np.mean(history['metrics2']) 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/224613.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 上午11:32
下一篇 2026年3月17日 上午11:32


相关推荐

  • ubuntu安装微信

    ubuntu安装微信nbsp 思路 1 下载安装 deepin wine 用于启动微信 2 下载安装微信 deb 可能会遇到的问题 1 盲目查看博客滴安装 如下面这段 这段安装是没有错的 错就错在安装的 deepin wine 版本为 2 18 12 这个版本不支持微信最新版本 nbsp deepin com wechat 2 6 8 65deepin0 i386 deb 安装必要的工具及 deepin wine 依赖 sudoaptinsta git

    2026年3月26日
    3
  • hadoop如何查看文件系统

    hadoop如何查看文件系统

    2021年8月25日
    67
  • Android ActivityManager一些API介绍

    Android ActivityManager一些API介绍Android 中 Java 层的 ActivityMana 类中封装了很多 API 可以供我们查询当前系统的很多信息 包括 内存 进程 Process 任务栈 Task 服务 Service 等的相关信息 利用这些信息可以进行一些有用的判断 例如判断当前系统内存是否不足 指定 Service 是否在运行中 ActivityMana 类封装了很多 API 方法供上层调用 具体负责管理 Activity Service 等组件的是 ActivityMana AMS

    2026年3月18日
    1
  • js中foreach和for循环的区别

    js中foreach和for循环的区别1 foreach 定义 foreach 又叫做增强 for 循环 相当于 for 循环的简化版 因此在一些较复杂的循环中不适用 结构 foreach 元素类型元素名称 循环对象 数组 集合 循环语句 特点 foreach 在循环次数未知或者计算起来较复杂的情况下效率比 for 循环高 2 foreach 与 for 循环的明显差别在于 foreach 循环时循环对象 数组 集合 被锁定 不能对循环对象中的内容进行增删改操作 3 实例 for 循环 可以修改循环语句 vara

    2026年3月16日
    2
  • NFS服务器搭建(配置web服务器)

    NFS服务简介什么是NFS?NFS挂载原理:RPC与NFS通讯原理:NFS客户端和NFS服务器通讯过程:Linux下NFS服务器部署NFS服务所需软件及主要配置文件:服务端安装NFS服务步骤:NFS客户端挂载配置:在Window上挂载NFS

    2022年4月13日
    86
  • PHP 使用 ElasticSearch 做搜索

    PHP 使用 ElasticSearch 做搜索

    2022年2月13日
    50

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号