Python多分类问题pr曲线绘制(含代码)

Python多分类问题pr曲线绘制(含代码)研究了三天的多分类 pr 曲线问题终于在昨天晚上凌晨一点绘制成功了 现将所学所感记录一下 一来怕自己会忘可以温故一下 二来希望能给同样有疑惑的铁子们一些启迪 下图为我画的 pr 曲线 因为准确度超过了 97 所以曲线很饱和 首先了解一下二分类中的 pr 曲线是怎么画的 p 是 precition 是查准率 也是我们常用到的准确率 r 是 recall 是查全率 也叫召回率 上图为测试结果的混淆矩阵 表示一个数据集上的所有测试结果 其中竖列均为测试结果 即分类器预测概率大于 0 5 为正类 小于 0

研究了三天的多分类pr曲线问题终于在昨天晚上凌晨一点绘制成功了!!
现将所学所感记录一下,一来怕自己会忘可以温故一下,二来希望能给同样有疑惑的铁子们一些启迪!

首先了解一下二分类中的pr曲线是怎么画的?

混淆矩阵
上图为测试结果的混淆矩阵,表示一个数据集上的所有测试结果。
其中竖列均为测试结果,即分类器预测概率大于0.5为正类,小于0.5即为负类。
横列表示groundtruth,即真实的类别。


TP 表示正确分出正例的数量;
FN 表示把正例错分为反例的数量;
TN 表示正确分出反例的数量;
FP表示把反例错分为正例的数量。


准确率: P = TP /(TP+FP)
召回率: R = TP / (TP+FN)

上面是对精确率和召回率的简单介绍,下面进入正题!


当阈值变小时,更多样本会被测试成飞机,虚线下移。假设取极限,阈值为0,那么所有样本都会被预测为飞机,召回率最大,为1;而精确率为 5/10 等于0.5。同理,阈值变大,虚线上移,精确率会变高,但召回率反而变低。

在设置阈值的时候,有两种方法:
1、从0-1之间按照等间隔设置,比如0,0.1,0.2,…,0.9,1.0。这样能得到10组 “p” “r” 值。当然也可以把间隔设置的小一点,可以得到更多组 “p” “r” 值。
2、把所有样本的概率预测值从小到大排序去重,并以此数列分别为阈值,进行计算 “p” “r” 值,可以得到更多组 “p” “r” 值。

绘制多分类的pr曲线

首先多分类方法无法绘制标准的pr曲线。

对于多分类问题,针对每一种类别都可以得到相对应的精确率和召回率,这样多分类问题就可以得到多组 “p” “r” 值,(P1, R1), (P2, R2), …, (Pn, Rn)。

对此计算平均值,就可以得到一组平均精确率和召回率,又叫做 “宏精确率”macro-P)和 “宏召回率”(macro-R)。由此画得的曲线应该叫做 “宏pr曲线”

因此我们需要先确定出数据集的测试结果。包括测试集的ground-truth类别,预测类别,以及对于每个测试样本的预测概率,保存成.txt文件。下图为.txt文件中的部分数据:

2 2 0.0000 0.0000 1.0000 0 0 0.9748 0.0252 0.0000 1 1 0.0000 1.0000 0.0000 2 1 0.0000 0.7629 0.2371 0 0 0.9999 0.0001 0.0000 1 1 0.0000 0.9996 0.0004 2 2 0.0000 0.0360 0.9640 0 0 0.9954 0.0046 0.0000 1 1 0.0000 0.9997 0.0003 2 2 0.0000 0.0000 1.0000 1 1 0.0000 1.0000 0.0000 1 1 0.0000 1.0000 0.0000 0 0 0.9972 0.0028 0.0000 1 1 0.0000 1.0000 0.0000 

以三分类为例,上图是从左到右分别是标签预测类别、以及预测概率(三类别,因此有三个概率),这个根据自己的分类的情况,从预测结果中提取出来就行。

这是提取.txt文件的参考代码:

# 提取.txt文件的参考代码 clses是分类标签列表,preds是预测结果列表,pred_score是预测得分。 print("Saving files to txt....") with open("pr_curve.txt", 'w') as pr: for i in range(len(clses)): pr.write(str(clses[i]) + " " + str(preds[i]) + " " + str(format(pred_score[i][0], '.4f')) + " " + str(format(pred_score[i][1], '.4f')) + " " + str(format(pred_score[i][2], '.4f')) + "\n") print("All files have been written!") 

下面是计算“宏pr值”以及绘制pr曲线的代码(含注释):

import numpy as np import matplotlib.pyplot as plt score_path = "./pr_curve.txt" # 文件路径 with open(score_path, 'r') as f: files = f.readlines() # 读取文件 lis_all = [] for file in files: _, _, s1, s2, s3 = file.strip().split(" ") lis_all.append(s1) lis_all.append(s2) lis_all.append(s3) lis_order = sorted(set(lis_all)) # 记录所有得分情况,并去重从小到大排序,寻找各个阈值点 macro_precis = [] macro_recall = [] for i in lis_order: true_p0 = 0 # 真阳 true_n0 = 0 # 真阴 false_p0 = 0 # 假阳 false_n0 = 0 # 假阴 true_p1 = 0 true_n1 = 0 false_p1 = 0 false_n1 = 0 true_p2 = 0 true_n2 = 0 false_p2 = 0 false_n2 = 0 for file in files: cls, pd, n0, n1, n2 = file.strip().split(" ") # 分别计算比较各个类别的得分,分开计算,各自为二分类, # 最后求平均,得出宏pr if float(n0) >= float(i) and cls == '0': # 遍历所有样本,第0类为正样本,其他类为负样本, true_p0 = true_p0 + 1 # 大于等于阈值,并且真实为正样本,即为真阳, elif float(n0) >= float(i) and cls != '0': # 大于等于阈值,真实为负样本,即为假阳; false_p0 = false_p0 + 1 # 小于阈值,真实为正样本,即为假阴 elif float(n0) < float(i) and cls == '0': false_n0 = false_n0 + 1 if float(n1) >= float(i) and cls == '1': # 遍历所有样本,第1类为正样本,其他类为负样本 true_p1 = true_p1 + 1 elif float(n1) >= float(i) and cls != '1': false_p1 = false_p1 + 1 elif float(n1) < float(i) and cls == '1': false_n1 = false_n1 + 1 if float(n2) >= float(i) and cls == '2': # 遍历所有样本,第2类为正样本,其他类为负样本 true_p2 = true_p2 + 1 elif float(n2) >= float(i) and cls != '2': false_p2 = false_p2 + 1 elif float(n2) < float(i) and cls == '2': false_n2 = false_n2 + 1 prec0 = (true_p0+0.00000000001) / (true_p0 + false_p0 + 0.00000000001) # 计算各类别的精确率,小数防止分母为0 prec1 = (true_p1+0.00000000001) / (true_p1 + false_p1 + 0.00000000001) prec2 = (true_p2+0.00000000001) / (true_p2 + false_p2 + 0.00000000001) recall0 = (true_p0+0.00000000001)/(true_p0+false_n0 + 0.00000000001) # 计算各类别的召回率,小数防止分母为0 recall1 = (true_p1+0.00000000001) / (true_p1 + false_n1+0.00000000001) recall2 = (true_p2+0.00000000001)/(true_p2+false_n2 + 0.00000000001) precision = (prec0 + prec1 + prec2)/3 recall = (recall0 + recall1 + recall2)/3 # 多分类求得平均精确度和平均召回率,即宏macro_pr macro_precis.append(precision) macro_recall.append(recall) macro_precis.append(1) macro_recall.append(0) print(macro_precis) print(macro_recall) x = np.array(macro_recall) y = np.array(macro_precis) plt.figure() plt.xlim([-0.01, 1.01]) plt.ylim([-0.01, 1.01]) plt.xlabel('recall') plt.ylabel('precision') plt.title('PR curve') plt.plot(x, y) plt.show() 

代码是针对三分类写的,当然五分类多分类等,在代码里添加修改就可以了。

首先是了解pr曲线原理;
然后得到包含标签、预测类别和预测得分的.txt文件;
最后绘制pr曲线。

到这里多分类绘制pr曲线就介绍完毕了,早上写到了10点半,主要是想赶紧记下来,不然后面自己肯定又会懒惰了。

日常学习记录,一起交流讨论吧!侵权联系~

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/231749.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • KEIL / MDK生成BIN文件的两种方式「建议收藏」

    KEIL / MDK生成BIN文件的两种方式「建议收藏」KEIL/MDK生成BIN文件的两种方式1KEIL工程配置入口点击“魔术棒”图标(OptionforTarget)在AfterBuild/Rebuild选项卡中,勾选"Run

    2022年7月2日
    32
  • java中保留两位小数的输出

    java中保留两位小数的输出例如平方和与倒数和,最后输出要求保留两位小数,下面这个程序的注意点已用红色字体标记importjava.text.DecimalFormat;importjava.util.Scanner;publicclassMain{publicstaticvoidmain(String[]args){//TODOAuto-generatedmet

    2022年7月8日
    31
  • executeUpdate()返回值「建议收藏」

    executeUpdate()返回值「建议收藏」executeUpdate()是PreparedStatement对象众多的抽象方法中的一个在jdk1.8中executeUpdate()有两种返回值(搜索一下PreparedStatement即可查看到),即有两种返回值不同的方法。返回值分别是:int和ResultSet在使用数据库的项目中作者用到了返回值为int的方法,查询了一下相关的解释intexecuteUpdate()…

    2022年10月20日
    4
  • 详解Nacos和Eureka的区别「建议收藏」

    详解Nacos和Eureka的区别「建议收藏」1.Eureka是什么Eureka是SpringCloud微服务框架默认的也是推荐的服务注册中心,由Netflix公司与2012将其开源出来,Eureka基于REST服务开发,主要用于实现AWS云的中服务定位,以实现中间层服务器的负载均衡和故障转移,2018年7月份……………………………

    2022年8月21日
    14
  • java calendar获取年_Java Calendar获取年、月、日、时间,设置年、月、日

    java calendar获取年_Java Calendar获取年、月、日、时间,设置年、月、日JavaCalendar获取年、月、日、时间Calendarc=Calendar.getInstance(TimeZone.getTimeZone(“GMT+08:00”));//获取东八区时间intyear=c.get(Calendar.YEAR);//获取年intmonth=c.get(Calendar.MONTH)+1;//获取月份,0表示1月份intday=c.get(Calendar.DAY_OF_MONTH);//获取当前

    2022年5月18日
    51
  • JMeter参数化4种实现方式「建议收藏」

    JMeter参数化4种实现方式「建议收藏」前言下面是目前JMeter系列已有博文列表,持续更新中:JMeter安装与接口测试入门JMeter接口测试之断言实现JMeter逻辑控制器JMeter配置元件JMeter操作Mysql数据库BeanShellSampler与BeanShell断言JMeterLinux下执行测试JMeter自定义日志与日志分析1参数化释义什么是参数化?从字面上去理解的话,就是事先准备好数据(广义上来说,可以是具体的数据值,也可以是数据生成规则),而非在脚本中写死,脚本执行时从准备好的数据中

    2022年6月24日
    26

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号