pytorch – K折交叉验证过程说明及实现

pytorch – K折交叉验证过程说明及实现代码主要核心思想来自 https www cnblogs com JadenFK3326 p 12164519 htmlK 折交叉交叉验证的过程如下 以 200 条数据 十折交叉验证为例子 十折也就是将数据分成 10 组 进行 10 组训练 每组用于测试的数据为 数据总条数 组数 即每组 20 条用于 valid 180 条用于 train 每次 valid 的都是不同的 1 将 200 条数据 分成按照数据

代码主要核心思想来自:https://www.cnblogs.com/JadenFK3326/p/12164519.html

K折交叉交叉验证的过程如下:

以200条数据,十折交叉验证为例子,十折也就是将数据分成10组,进行10组训练,每组用于测试的数据为:数据总条数/组数,即每组20条用于valid,180条用于train,每次valid的都是不同的。

(1)将200条数据,分成按照 数据总条数/组数(折数),进行切分。然后取出第i份作为第i次的valid,剩下的作为train

(2)将每组中的train数据利用DataLoader和Dataset,进行封装。

(3)将train数据用于训练,epoch可以自己定义,然后利用valid做验证。得到一次的train_loss和 valid_loss。

(4)重复(2)(3)步骤,得到最终的 averge_train_loss和averge_valid_loss

上述过程如下图所示:

pytorch - K折交叉验证过程说明及实现

上述的代码如下:

import torch import torch.nn as nn from torch.utils.data import DataLoader,Dataset import torch.nn.functional as F from torch.autograd import Variable 构造的训练集# x = torch.rand(100,28,28) y = torch.randn(100,28,28) x = torch.cat((x,y),dim=0) label =[1] *100 + [0]*100 label = torch.tensor(label,dtype=torch.long) 网络结构# class Net(nn.Module): #定义Net def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28*28, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 2) def forward(self, x): x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features #定义dataset# class TraindataSet(Dataset): def __init__(self,train_features,train_labels): self.x_data = train_features self.y_data = train_labels self.len = len(train_labels) def __getitem__(self,index): return self.x_data[index],self.y_data[index] def __len__(self): return self.len k折划分 def get_k_fold_data(k, i, X, y): 此过程主要是步骤(1) # 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据 assert k > 1 fold_size = X.shape[0] // k # 每份的个数:数据总条数/折数(组数) X_train, y_train = None, None for j in range(k): idx = slice(j * fold_size, (j + 1) * fold_size) #slice(start,end,step)切片函数 idx 为每组 valid X_part, y_part = X[idx, :], y[idx] if j == i: 第i折作valid X_valid, y_valid = X_part, y_part elif X_train is None: X_train, y_train = X_part, y_part else: X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接 y_train = torch.cat((y_train, y_part), dim=0) #print(X_train.size(),X_valid.size()) return X_train, y_train, X_valid,y_valid def k_fold(k, X_train, y_train, num_epochs=3,learning_rate=0.001, weight_decay=0.1, batch_size=5): train_loss_sum, valid_loss_sum = 0, 0 train_acc_sum ,valid_acc_sum = 0,0 for i in range(k): data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据 net = Net() 实例化模型 每份数据进行训练,体现步骤三# train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,\ weight_decay, batch_size) print('*'*25,'第',i+1,'折','*'*25) print('train_loss:%.6f'%train_ls[-1][0],'train_acc:%.4f\n'%valid_ls[-1][1],\ 'valid loss:%.6f'%valid_ls[-1][0],'valid_acc:%.4f'%valid_ls[-1][1]) train_loss_sum += train_ls[-1][0] valid_loss_sum += valid_ls[-1][0] train_acc_sum += train_ls[-1][1] valid_acc_sum += valid_ls[-1][1] print('#'*10,'最终k折交叉验证结果','#'*10) #体现步骤四 print('train_loss_sum:%.4f'%(train_loss_sum/k),'train_acc_sum:%.4f\n'%(train_acc_sum/k),\ 'valid_loss_sum:%.4f'%(valid_loss_sum/k),'valid_acc_sum:%.4f'%(valid_acc_sum/k)) 训练函数# def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size): train_ls, test_ls = [], [] 存储train_loss,test_loss dataset = TraindataSet(train_features, train_labels) train_iter = DataLoader(dataset, batch_size, shuffle=True) 将数据封装成 Dataloder 对应步骤(2) #这里使用了Adam优化算法 optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay) for epoch in range(num_epochs): for X, y in train_iter: 分批训练 output = net(X) loss = loss_func(output,y) optimizer.zero_grad() loss.backward() optimizer.step() 得到每个epoch的 loss 和 accuracy train_ls.append(log_rmse(0,net, train_features, train_labels)) if test_labels is not None: test_ls.append(log_rmse(1,net, test_features, test_labels)) #print(train_ls,test_ls) return train_ls, test_ls def log_rmse(flag,net,x,y): if flag == 1: valid 数据集 net.eval() output = net(x) result = torch.max(output,1)[1].view(y.size()) corrects = (result.data == y.data).sum().item() accuracy = corrects*100.0/len(y) # 5 是 batch_size loss = loss_func(output,y) net.train() return (loss.data.item(),accuracy) loss_func = nn.CrossEntropyLoss() 申明loss函 k_fold(10,x,label) k=10,十折交叉验证

上述代码中,直接按照顺序从x中每次截取20条作为valid,也可以先打乱然后在截取,这样效果应该会更好。如下所示:

import random import torch x = torch.rand(100,28,28) y = torch.randn(100,28,28) x = torch.cat((x,y),dim=0) label =[1] *100 + [0]*100 label = torch.tensor(label,dtype=torch.long) index = [i for i in range(len(x))] random.shuffle(index) x = x[index] label = label[index] 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/203432.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月19日 下午9:49
下一篇 2026年3月19日 下午9:49


相关推荐

  • Idea激活码最新教程2024.2.3版本,永久有效激活码,亲测可用,记得收藏

    Idea激活码最新教程2024.2.3版本,永久有效激活码,亲测可用,记得收藏Idea 激活码教程永久有效 2024 2 3 激活码教程 Windows 版永久激活 持续更新 Idea 激活码 2024 2 3 成功激活

    2025年5月30日
    6
  • 记忆化搜索算法

    记忆化搜索算法概述记忆化搜索算法事实上是一种对递归算法的优化因为在递归算法中有很多重复计算,导致了非常离谱的时间和空间复杂度所以我们采用记住计算结果的方式,能很大程度上减少复杂度例题1AcWing901.滑雪例题2AcWing2067.走方格…

    2022年7月26日
    18
  • python获取当前时间戳和日期

    python获取当前时间戳和日期importtime datetime 时间戳 print time time 今天的日期 print datetime date today 转载来源 https blog csdn net weixin article details

    2026年3月19日
    3
  • 锁定文件失败 未能启动虚拟机_win10无法分析无人应答文件

    锁定文件失败 未能启动虚拟机_win10无法分析无人应答文件首先看一下出现的错误:出现这个错误我也是纠结了好半天,试了网上的方法结果还是没有效果,比如下面的这个方法也不行,不知道是不是我机器的问题:后来误打误撞地把问题解决了,在创建虚拟机的最后一步将勾选的“创建后开启此虚拟机(P)”去的,即不勾选,创建完后再手动启动虚拟机,就可以了,如下图所示:…

    2025年11月14日
    8
  • Win系统 – 单通道 16G 内存 VS 双通道 16G 内存

    Win系统 – 单通道 16G 内存 VS 双通道 16G 内存单通道16GB测试成绩双通道16GB(8+8)测试成绩总结通过以上的一系列测试,不难看出单通道16GB与双通道16GB还是有一些差别的,究竟如何决择,笔者给大家分析一下。通过基础频率测试看出单通道16GB与双通道16GB内存条在性能参数、读取、写入、拷贝、复制、延迟及总体内存性能方面,还是存在着很大差距的;通过应用程序测试看出双通道16GB在解压缩方面比单通道16GB的速度要快接近1M/s,同理可以看出在双通道16GB在处理海量照片,视频软件等专业软件的能力要高出单通..

    2022年6月15日
    75
  • 异步传输模式atm实际上是两种交换技术的结合_异步转移模式ATM

    异步传输模式atm实际上是两种交换技术的结合_异步转移模式ATMAsynchronousTransferMode(ATM)异步传输模式(ATM)ATM是一项数据传输技术。它适用于局域网和广域网,它具有高速数据传输率和支持许多种类型如声音、数据、传真、实时视频、CD质量音频和图象的通信。ATM是在LAN或WAN上传送声音、视频图象和数据的宽带技术。它是一项信元中继技术,数据分组大小固定。你可将信元想像成一种运输设备,能够把数据块从一个设备经过ATM交

    2026年2月8日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号