语义分割的评价指标_语义分割数据集

语义分割的评价指标_语义分割数据集包括:像素准确率、类别像素准确率、类别平均像素准确率、交并比、平均交并比、频权交并比。

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

前言

现存其实已经有很多博客实现了这个代码,但是可能不完整或者不能直接用于测试集的指标计算,这里简单概括一下。

一些概念、代码参考: [1] 憨批的语义分割9——语义分割评价指标mIOU的计算

[2]【语义分割】评价指标:PA、CPA、MPA、IoU、MIoU详细总结和代码实现(零基础从入门到精通系列!)

[3] 【语义分割】评价指标总结及代码实现

混淆矩阵

语义分割的各种评价指标都是基于混淆矩阵来的。

对于一个只有背景0和目标1的语义分割任务来说,混淆矩阵可以简单理解为:

TP(1被认为是1) FP(0被认为是1)
FN(1被认为是0) TN(0被认为是0)

各种指标的计算

1. 像素准确率 PA =(TP+TN)/(TP+TN+FP+FN)

2. 类别像素准确率 CPA = TP / (TP+FP)

3. 类别平均像素准确率 MPA = (CPA1+…+CPAn)/ n

4. 交并比 IoU = TP / (TP+FP+FN) 

5. 平均交并比 MIoU = (IoU1+…+IoUn) / n

6. 频权交并比 FWIoU =  [ (TP+FN) / (TP+FP+TN+FN) ] * [ TP / (TP + FP + FN) ]

代码实现

"""
https://blog.csdn.net/sinat_29047129/article/details/103642140
https://www.cnblogs.com/Trevo/p/11795503.html
refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
"""
import numpy as np
import os
from PIL import Image
__all__ = ['SegmentationMetric']

"""
confusionMetric
P\L     P    N

P      TP    FP

N      FN    TN

"""


class SegmentationMetric(object):
    def __init__(self, numClass):
        self.numClass = numClass
        self.confusionMatrix = np.zeros((self.numClass,) * 2) # 混淆矩阵n*n,初始值全0

    # 像素准确率PA,预测正确的像素/总像素
    def pixelAccuracy(self):
        # return all class overall pixel accuracy
        # acc = (TP + TN) / (TP + TN + FP + TN)
        acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
        return acc

    # 类别像素准确率CPA,返回n*1的值,代表每一类,包括背景
    def classPixelAccuracy(self):
        # return each category pixel accuracy(A more accurate way to call it precision)
        # acc = (TP) / TP + FP
        classAcc = np.diag(self.confusionMatrix) / np.maximum(self.confusionMatrix.sum(axis=1),1)
        return classAcc

    # 类别平均像素准确率MPA,对每一类的像素准确率求平均
    def meanPixelAccuracy(self):
        classAcc = self.classPixelAccuracy()
        meanAcc = np.nanmean(classAcc)
        return meanAcc

    # MIoU
    def meanIntersectionOverUnion(self):
        # Intersection = TP Union = TP + FP + FN
        # IoU = TP / (TP + FP + FN)
        intersection = np.diag(self.confusionMatrix)
        union = np.maximum(np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
            self.confusionMatrix), 1)
        IoU = intersection / union
        mIoU = np.nanmean(IoU)
        return mIoU

    # 根据标签和预测图片返回其混淆矩阵
    def genConfusionMatrix(self, imgPredict, imgLabel):
        # remove classes from unlabeled pixels in gt image and predict
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        label = self.numClass * imgLabel[mask].astype(int) + imgPredict[mask]
        count = np.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.reshape(self.numClass, self.numClass)
        return confusionMatrix

    def Frequency_Weighted_Intersection_over_Union(self):
        # FWIOU =     [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
        freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
        iu = np.diag(self.confusionMatrix) / (
                np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
                np.diag(self.confusionMatrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    # 更新混淆矩阵
    def addBatch(self, imgPredict, imgLabel):
        assert imgPredict.shape == imgLabel.shape # 确认标签和预测值图片大小相等
        self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)

    # 清空混淆矩阵
    def reset(self):
        self.confusionMatrix = np.zeros((self.numClass, self.numClass))

def old():
    imgPredict = np.array([0, 0, 0, 1, 2, 2])
    imgLabel = np.array([0, 0, 1, 1, 2, 2])
    metric = SegmentationMetric(3)
    metric.addBatch(imgPredict, imgLabel)
    acc = metric.pixelAccuracy()
    macc = metric.meanPixelAccuracy()
    mIoU = metric.meanIntersectionOverUnion()
    print(acc, macc, mIoU)

def evaluate1(pre_path, label_path):
    acc_list = []
    macc_list = []
    mIoU_list = []
    fwIoU_list = []

    pre_imgs = os.listdir(pre_path)
    lab_imgs = os.listdir(label_path)

    for i, p in enumerate(pre_imgs):
        imgPredict = Image.open(pre_path+p)
        imgPredict = np.array(imgPredict)
        # imgPredict = imgPredict[:,:,0]
        imgLabel = Image.open(label_path+lab_imgs[i])
        imgLabel = np.array(imgLabel)
        # imgLabel = imgLabel[:,:,0]

        metric = SegmentationMetric(2) # 表示分类个数,包括背景
        metric.addBatch(imgPredict, imgLabel)
        acc = metric.pixelAccuracy()
        macc = metric.meanPixelAccuracy()
        mIoU = metric.meanIntersectionOverUnion()
        fwIoU = metric.Frequency_Weighted_Intersection_over_Union()

        acc_list.append(acc)
        macc_list.append(macc)
        mIoU_list.append(mIoU)
        fwIoU_list.append(fwIoU)

        # print('{}: acc={}, macc={}, mIoU={}, fwIoU={}'.format(p, acc, macc, mIoU, fwIoU))

    return acc_list, macc_list, mIoU_list, fwIoU_list

def evaluate2(pre_path, label_path):
    pre_imgs = os.listdir(pre_path)
    lab_imgs = os.listdir(label_path)

    metric = SegmentationMetric(2)  # 表示分类个数,包括背景
    for i, p in enumerate(pre_imgs):
        imgPredict = Image.open(pre_path+p)
        imgPredict = np.array(imgPredict)
        imgLabel = Image.open(label_path+lab_imgs[i])
        imgLabel = np.array(imgLabel)

        metric.addBatch(imgPredict, imgLabel)

    return metric

if __name__ == '__main__':
    pre_path = './pre_path/'
    label_path = './label_path/'

    # 计算测试集每张图片的各种评价指标,最后求平均
    acc_list, macc_list, mIoU_list, fwIoU_list = evaluate1(pre_path, label_path)
    print('final1: acc={:.2f}%, macc={:.2f}%, mIoU={:.2f}%, fwIoU={:.2f}%'
          .format(np.mean(acc_list)*100, np.mean(macc_list)*100,
                  np.mean(mIoU_list)*100, np.mean(fwIoU_list)*100))

    # 加总测试集每张图片的混淆矩阵,对最终形成的这一个矩阵计算各种评价指标
    metric = evaluate2(pre_path, label_path)
    acc = metric.pixelAccuracy()
    macc = metric.meanPixelAccuracy()
    mIoU = metric.meanIntersectionOverUnion()
    fwIoU = metric.Frequency_Weighted_Intersection_over_Union()
    print('final2: acc={:.2f}%, macc={:.2f}%, mIoU={:.2f}%, fwIoU={:.2f}%'
          .format(acc*100, macc*100, mIoU*100, fwIoU*100))

说明

1. 使用上述代码时只需修改pre_path和label_path即可。label_path是真实标签的路径,为8位图;pre_path是训练好模型后,测试集生成的分割结果的路径,也是8位图。

metric = SegmentationMetric(2) 中,2表示的是该分割图的类别总数,包含背景,需对应修改。

2. 上述给出了两种指标的计算方式。

evaluate1是对测试集中产生的每张预测图片都计算对应的各种指标,最后对所有图片的结果进行求均值;

evaluate2是把测试集中产生的每张预测图片的混淆矩阵都加在一起,成为一个整个的混淆矩阵,最后对这一个矩阵求各种指标。

3. 我的测试结果如下:

final1: acc=93.68%, macc=79.05%, mIoU=69.85%, fwIoU=89.09%
final2: acc=93.68%, macc=78.72%, mIoU=70.71%, fwIoU=88.88%

可以看到,两种计算方法的结果在PA上是相等的,因为都是 正确的像素/总像素 ,不管是先加总再除还是先除再取平均结果都是相同的;而其他指标结果值略有不同。

一般论文中使用的是第2种,当图片本身为1600×1200时,无论是直接对原图进行评估还是将其裁剪成12张400×400大小图片进行评估,第2种的计算结果相等,而第1种结果不同。

4. 如果要打印每个类别的IoU或PA,在对应方法中返回即可。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/171679.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • cBioPortal简介[通俗易懂]

    cBioPortal简介[通俗易懂]使用cBioCancerGenomicsPortal综合分析癌症基因和临床资料cBioCancerGenomicsPortal简介cBioCancerGenomicsPortal

    2022年8月3日
    8
  • pad图和n-s图_N S W

    pad图和n-s图_N S W(未完成_N-S图、PAD图概念未写)1、记录StudentRecord给出即将大学毕业的学生的姓名和平均分(GPA)。我们的目的是建立一个参加毕业典礼的学生表。候选毕业的学生表从文件”StudRecs”读入。因为学校规定:GPA低于minGPA的学生不能毕业,因此那些平均分低于minGPA的学生不参加毕业典礼。另外,记录那些选择不参加毕业典礼的学生名单,将名单保存至文件”NoAttend”中,按每行一个学生姓名保存。删除这些选择不参加毕业典礼的学生,最终生成参加毕业典礼的学生表。2、画出下列程序流

    2022年8月13日
    4
  • IDEA 控制台乱码 解决方法[通俗易懂]

    IDEA 控制台乱码 解决方法[通俗易懂]IDEA如果不进行配置的话,运行程序时控制台就会中文乱码,严重影响我们对信息的观察和程序的跟踪.非常的痛苦,那么上解决方法

    2025年5月26日
    3
  • 发票查验平台查询发票总显示系统繁忙的解决办法

    发票查验平台查询发票总显示系统繁忙的解决办法

    2021年11月20日
    46
  • chrome webdriver下载_webdriver.chrome()

    chrome webdriver下载_webdriver.chrome()请对应自己的谷歌浏览器的版本下载chrome的webdriver:点击下载windows环境变量配置1、webdriver文件位置可以自定义位置,如:d:\selenium环境变量,的文件夹下也可以放在C:\ProgramFiles(x86)\Google\Chrome\Application的文件夹下2、系统环境变量PATH按照图的指示,1->2->3->…

    2022年9月19日
    3
  • Spring中Responsebody注解的作用[通俗易懂]

    Spring中Responsebody注解的作用[通俗易懂]好长一段时间以来都只是写些测试代码,好久没写项目代码了,以至于sping那套东西日渐生疏了。最近在折腾一个小项目,写了一个controller用来响应ajax请求,结果断点调试发现一直返回"404…notresponse…",折腾了快2小时,一直没想到是注解的问题,万般无赖之下上了度娘,方才如梦初醒,特意记录一下,一来提醒一下自己,二来也让跟我遇到一样问题的朋友少受些折磨。这个注解表示…

    2022年5月8日
    69

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号