半监督学习代码实战

半监督学习代码实战sklearn 官方例子 用半监督学习做数字识别什么是半监督学习半监督学习很重要 为什么呢 因为人工标注数据成本太高 现在大家参加比赛的数据都是标注好的了 那么如果老板给你一份没有标注的数据 而且有几百万条 让你做个分类什么的 你怎么办 不可能等标注好数据再去训练模型吧 所以你得会半监督学习算法 不过我在这里先打击大家一下 用 sklearn 的包做不了大数据量的半监督学习 我用的数据量大概在 15000 条以上就要报 MemoryError 错误了 这个是我最讨厌的错误 暂时我还没有解决的办法 如果

sklearn官方例子——用半监督学习做数字识别
 

什么是半监督学习

半监督学习很重要,为什么呢?因为人工标注数据成本太高,现在大家参加比赛的数据都是标注好的了,那么如果老板给你一份没有标注的数据,而且有几百万条,让你做个分类什么的,你怎么办?不可能等标注好数据再去训练模型吧,所以你得会半监督学习算法。

不过我在这里先打击大家一下,用sklearn的包做不了大数据量的半监督学习,我用的数据量大概在15000条以上就要报MemoryError错误了,这个是我最讨厌的错误。暂时我还没有解决的办法,如果同志们是小数据量,那就用这个做着玩玩吧

算法流程

假设我们有一份数据集,共330个数字,其中前十个是已知的,已经标注好了,后320个是未知的,需要我们预测出来的。

  • 首先把这330个数据全部都放到半监督学习算法里,训练模型,预测那320个标签
  • 然后用某种方法(看下面代码的操作)得知这320个数据里最不确定的前5个数据,对它进行人工标注,然后把它放到之前的10个数据里,现在就有15个已知数据了
  • 这样循环个几次,已标注的数据就变多了,那么分类器的效果肯定也就变好了

  • 一共330个点,都是已经标注好的了,我们把其中的320个点赋值为-1,这样就可以假装这320个点都是没有标注的了
  • 训练一个只有10个标记点的标签传播模型
  • 然后从所有数据中选择要标记的前五个最不确定的点,把它们(带有正确标签)放到原来的10个点中
  • 接下来可以训练15个标记点(原始10个 + 5个新点)
  • 重复这个过程四次,就可以使用30个标记好的点来训练模型
  • 可以通过改变max_iterations将这个值增加到30以上

 

LabelSpreading是一个半监督学习模型

 

import numpy as np import matplotlib.pyplot as plt from scipy import stats from sklearn import datasets from sklearn.semi_supervised import label_propagation from sklearn.metrics import classification_report,confusion_matrix # 再加下面这个,不然会报错 from scipy.sparse.csgraph import * digits = datasets.load_digits() rng = np.random.RandomState(0) # indices是随机产生的0-1796个数字,且打乱 #indices:[1081 1707 927 ... 1653 559 684] indices = np.arange(len(digits.data)) rng.shuffle(indices) # 取前330个数字来玩 X = digits.data[indices[:330]] y = digits.target[indices[:330]] images = digits.images[indices[:330]] n_total_samples = len(y) # 330 n_labeled_points = 10 # 标注好的数据共10条 max_iterations = 5 # 迭代5次 #未标注的数据320条 #即[10 11 12 ... 329] unlabeled_indices = np.arange(n_total_samples)[n_labeled_points:] f = plt.figure() # 画图用的 for i in range(max_iterations): if len(unlabeled_indices) == 0: print("no unlabeled items left to label") # 没有未标记的标签了,全部标注好了 break y_train = np.copy(y) y_train[unlabeled_indices] = -1 #把未标注的数据全部标记为-1,也就是后320条数据 lp_model = label_propagation.LabelSpreading(gamma=0.25,max_iter=5) # 训练模型 lp_model.fit(X,y_train) predicted_labels = lp_model.transduction_[unlabeled_indices] # 预测的标签 true_labels = y[unlabeled_indices] # 真实的标签 print('') print(predicted_labels) print(true_labels) print('') cm = confusion_matrix(true_labels,predicted_labels, labels = lp_model.classes_) print("iteration %i %s" % (i,70 * "_")) # 打印迭代次数 print("Label Spreading model: %d labeled & %d unlabeled (%d total)" % (n_labeled_points,n_total_samples-n_labeled_points,n_total_samples)) print(classification_report(true_labels,predicted_labels)) print("Confusion matrix") print(cm) # 计算转换标签分布的熵 # lp_model.label_distributions_作用是Categorical distribution for each item pred_entropies = stats.distributions.entropy( lp_model.label_distributions_.T) # 选择分类器最不确定的前5位数字的索引 # 首先计算出所有的熵,也就是不确定性,然后从320个中选择出前5个熵最大的 # numpy.argsort(A)提取排序后各元素在原来数组中的索引。具体情况可看下面 # np.in1d 用于测试一个数组中的值在另一个数组中的成员资格,返回一个布尔型数组。具体情况可看下面 uncertainty_index = np.argsort(pred_entropies)[::1] uncertainty_index = uncertainty_index[ np.in1d(uncertainty_index,unlabeled_indices)][:5] # 这边可以确定每次选前几个作为不确定的数,最终都会加回到训练集 # 跟踪我们获得标签的索引 delete_indices = np.array([]) # 可视化前5次的结果 if i < 5: f.text(.05,(1 - (i + 1) * .183), 'model %d\n\nfit with\n%d labels' % ((i + 1),i*5+10),size=10) for index,image_index in enumerate(uncertainty_index): # image_index是前5个不确定标签 # index就是0-4 image = images[image_index] # 可视化前5次的结果 if i < 5: sub = f.add_subplot(5,5,index + 1 + (5*i)) sub.imshow(image,cmap=plt.cm.gray_r) sub.set_title("predict:%i\ntrue: %i" % ( lp_model.transduction_[image_index],y[image_index]),size=10) sub.axis('off') # 从320条里删除要那5个不确定的点 # np.where里面的参数是条件,返回的是满足条件的索引 delete_index, = np.where(unlabeled_indices == image_index) delete_indices = np.concatenate((delete_indices,delete_index)) unlabeled_indices = np.delete(unlabeled_indices,delete_indices) # n_labeled_points是前面不确定的点有多少个被标注了 n_labeled_points += len(uncertainty_index) f.suptitle("Active learning with label propagation.\nRows show 5 most" "uncertain labels to learn with the next model") plt.subplots_adjust(0.12,0.03,0.9,0.8,0.2,0.45) plt.show() 

半监督学习代码实战

半监督学习代码实战

半监督学习代码实战

 

参考:

https://www.jianshu.com/p/a21817a81890

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/225539.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 上午9:05
下一篇 2026年3月17日 上午9:06


相关推荐

  • HttpClient4.x 文件上传

    HttpClient4.x 文件上传演示gradle包引入compilegroup:’org.apache.httpcomponents’,name:’httpclient’,version:’4.5.3’上传HttpClientHTTP_CLIENT=HttpClients.createDefault();HttpPosthttpPost=newHttpPost(“http://localhost/fi

    2022年7月22日
    11
  • js数组转字符串,字符串转数组的方式

    js数组转字符串,字符串转数组的方式1 数组转字符串 1 数组中 toString 方法能够把每个元素转换为字符串 然后以逗号连接输出显示 2 toLocalStrin 方法与 toString 方法用法基本相同 主要区别在于 toLocalStrin 方法能够使用用户所在地区特定的分隔符把生成的字符串连接起来 形成一个字符串 3 join 方法可以把数组转换为字符串 不过它可以指定分隔符 在调用 join 方法时 可以传递一个参数作为分隔符来连接每个元素 如果省略参数 默认使用逗号作为分隔符 这时与 t

    2026年3月19日
    2
  • Mysql 根据经纬度计算距离

    Mysql 根据经纬度计算距离方式 1 st distance sphereSELECT st distance sphere point lng lat point 116 40 0 asjuliFROMta 没用除以 1000 所以是以米为单位方式 2 st distanceSELE st distance point lng lat point 116 40 0 1

    2026年3月17日
    2
  • cocos creator编写2048小游戏,发微信小游戏

    cocos creator编写2048小游戏,发微信小游戏

    2021年3月12日
    183
  • python+tkinter实现GUI界面调用即梦AI文生图片API接口

    python+tkinter实现GUI界面调用即梦AI文生图片API接口

    2026年3月12日
    2
  • 从硬件到软件,低代码定制安灯(Andon)成为MES系统的全新增长点

    从硬件到软件,低代码定制安灯(Andon)成为MES系统的全新增长点安灯不是 安上灯泡 的缩写 而是一个制造业信息化的专有名词 通过安灯 现场工作人员可以快速上报生产中遇到的各种问题 如品质异常 设备故障 缺料等 反馈给其他工位和生产管理人员 让问题能够在第一时间得到处理 减少对生产过程的影响 随着精益制造的理念深入人心 制造业对安灯的要求也在日益增长 本文将为大家介绍安灯从硬件设备 到定制化软件模块的演进过程 探讨作为现代化的 MES 系统的核心模块 究竟什么样的安灯才能满足不同类型生产线的需求 发挥出 MES 的最大价值 安灯 图片来自网络 从拉线到触摸屏 看安灯的发

    2026年3月17日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号