[Deep Learning] [Machine Learning] Metrics and Methods for Analyzing Classification Results (Confusion Matrix, TP, TN, FP, FN, Precision, Recall) (with Source Code)
Published: 2019-05-25


 



0 Preface

A while ago I worked on a classification project that used the ResNet18 and MobileNetV2 models to classify data, the former mainly for GPU deployment and the latter mainly for CPU deployment. Classification performance was analyzed by computing the confusion matrix together with accuracy, recall, and the F Score. These metrics are introduced in detail below.


 

1 Evaluation Metrics

1.1 TP, FP, FN, TN

Let's first expand these abbreviations: T stands for True, P for Positive, F for False, and N for Negative. Hence:

TP: the actual value is Positive and the predicted value is Positive, i.e. a true positive;

FP: the actual value is Negative but the predicted value is Positive, i.e. a false positive;

FN: the actual value is Positive but the predicted value is Negative, i.e. a false negative;

TN: the actual value is Negative and the predicted value is Negative, i.e. a true negative.

In statistics, an FP is known as a Type I error, and an FN as a Type II error.
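To make the four counts concrete, here is a minimal numpy sketch (illustrative only, with made-up labels; it is not part of the project code in Section 2):

import numpy as np

# Hypothetical binary labels: 1 = Positive, 0 = Negative
y_true = np.array([1, 1, 1, 0, 0, 0, 1, 0])
y_pred = np.array([1, 0, 1, 0, 1, 0, 1, 0])

TP = int(np.sum((y_true == 1) & (y_pred == 1)))  # true positives
FP = int(np.sum((y_true == 0) & (y_pred == 1)))  # false positives (Type I error)
FN = int(np.sum((y_true == 1) & (y_pred == 0)))  # false negatives (Type II error)
TN = int(np.sum((y_true == 0) & (y_pred == 0)))  # true negatives

print(TP, FP, FN, TN)  # 3 1 1 3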

 

1.2 Confusion Matrix

Arranging TP, FP, FN, and TN in a table gives the confusion matrix.

1.2.1 Binary classification

 

In binary classification, a sample has only two possible predictions: positive or negative.

Table 1  Confusion matrix for binary classification

                          Predicted
                     Positive   Negative
Actual   Positive       TP         FN
         Negative       FP         TN

For example, suppose we classify images of cats and dogs, taking cat as the positive class and dog as the negative class; every prediction is then one of two outcomes: cat or not cat (i.e. dog). Suppose there are 50 cat images and 53 dog images, of which 45 and 47 respectively are predicted correctly. The confusion matrix is shown in Table 2.

Table 2  Confusion matrix for the cat/dog binary task

                       Predicted
                     Cat     Dog
Actual   Cat          45       5
         Dog           6      47
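In code, the four cells of Table 2 relate to the per-class image totals as follows (a small sanity-check sketch using the numbers above):

import numpy as np

# Table 2: rows = actual (cat, dog), cols = predicted (cat, dog)
cm = np.array([[45, 5],
               [6, 47]])

TP, FN = cm[0, 0], cm[0, 1]  # cat is the positive class
FP, TN = cm[1, 0], cm[1, 1]

assert TP + FN == 50  # all 50 cat images
assert FP + TN == 53  # all 53 dog images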

1.2.2 Multi-class classification

Classification tasks are often not simple binary problems; with three or more classes the confusion matrix differs somewhat from the binary one. The confusion matrix for N classes is shown in Table 3.

Table 3  Confusion matrix for N-class classification

                             Predicted
                   class 1   class 2   ...   class N
Actual   class 1
         class 2
         ...
         class N

For another example, suppose we classify images of cats, dogs, and pigs, with 51, 52, and 49 images respectively. Of the cat images, 47, 1, and 3 are predicted as cat, dog, and pig; of the dog images, 1, 49, and 2; and of the pig images, 1, 0, and 48. The confusion matrix is shown in Table 4.

Table 4  Confusion matrix for the cat/dog/pig three-class task

                       Predicted
                     Cat    Dog    Pig
Actual   Cat          47      1      3
         Dog           1     49      2
         Pig           1      0     48
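An N-class confusion matrix is accumulated by simply counting (target, prediction) pairs, which is exactly what the script in Section 2 does; here is the core of it with some hypothetical pairs:

import numpy as np

num_cls = 3  # cat, dog, pig
cm = np.zeros((num_cls, num_cls), dtype=int)

# Hypothetical (target, predict) label pairs; the real script reads one
# image per line from a result .txt file
pairs = [(0, 0), (0, 2), (1, 1), (2, 2), (1, 0)]
for target, predict in pairs:
    cm[target][predict] += 1  # row = actual class, col = predicted class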

 

Nine times out of ten, the three cats misclassified as pigs were orange tabbies.

Each individual class can also be analyzed as a binary problem: the prediction either is that class or is not, and we can draw a separate confusion matrix for it. For the cat class, for example, predictions of dog and pig can be lumped together as "not cat"; the resulting confusion matrix is shown in Table 5.

Table 5  Confusion matrix for the cat class in the three-class task

                        Predicted
                      Cat    Not cat
Actual   Cat           47         4
         Not cat        2        99
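The per-class binary matrix can be read directly off the N×N matrix using row, column, and grand totals; the following sketch reproduces Table 5 from Table 4:

import numpy as np

cm = np.array([[47, 1, 3],    # Table 4: rows = actual, cols = predicted
               [1, 49, 2],
               [1, 0, 48]])

k = 0                         # index of the cat class
TP = cm[k, k]                 # cats predicted as cat
FN = cm[k].sum() - TP         # cats predicted as something else
FP = cm[:, k].sum() - TP      # non-cats predicted as cat
TN = cm.sum() - TP - FN - FP  # non-cats predicted as non-cat

print(TP, FN, FP, TN)  # 47 4 2 99, matching Table 5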

 

1.3 Second-Level Metrics

The main second-level metrics are accuracy, precision, recall, and specificity.

1.3.1 Accuracy

Accuracy: the proportion of correctly classified results among all observations. Accuracy describes the model as a whole. The formulas are:

(1) Multi-class model

ACC = number of correctly classified results / total number of observations

For Table 4, the model accuracy is ACC = \tfrac{47+49+48}{51+52+49} \approx 94.74%.

In a multi-class model, the accuracy for class k is ACC_{class\: k} = \tfrac{TP_{class\: k}+TN_{class\: k}}{TP_{class\: k}+TN_{class\: k}+FP_{class\: k}+FN_{class\: k}}.

For Table 5, the accuracy for the cat class in the three-class task is ACC_{cat} = \tfrac{47+99}{47+99+2+4} \approx 96.05%.

(2) Binary model

ACC = \tfrac{TP+TN}{TP+TN+FP+FN}

For Table 2, the model accuracy is ACC = \tfrac{45+47}{45+47+6+5} \approx 89.32%.
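Both computations are one-liners given the matrices above (a quick check of the two worked results):

import numpy as np

cm = np.array([[47, 1, 3],
               [1, 49, 2],
               [1, 0, 48]])

acc_model = np.trace(cm) / cm.sum()        # overall accuracy for Table 4
print('{:.2f}%'.format(acc_model * 100))   # 94.74%

TP, TN, FP, FN = 47, 99, 2, 4              # the cat class, Table 5
acc_cat = (TP + TN) / (TP + TN + FP + FN)
print('{:.2f}%'.format(acc_cat * 100))     # 96.05%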

 

1.3.2 Precision

Precision: among all results the model predicts as Positive, the proportion it gets right. The formula is:

PPV = \tfrac{TP}{TP+FP}

In Table 2, the precision for the cat class is PPV_{cat} = \tfrac{45}{45+6} \approx 88.24%; in Table 5, the precision for the cat class in the three-class task is PPV_{cat} = \tfrac{47}{47+2} \approx 95.92%.

Precision is about the predictions. Put simply, it matters when a false alarm ("wrongful conviction") is expensive and a miss ("fish slipping through the net") is cheap. Suppose we classify emails as spam (True) or not (False). If a spam email is misjudged as normal, we waste only a few seconds opening it; but if an important email is thrown into the spam folder, we may miss critical information. In this case FP should be as small as possible so that, with TP fixed, PPV is as large as possible.

In information retrieval, precision is also called the precision rate (查准率).
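As a quick numeric check of the Table 5 value:

TP, FP = 47, 2  # the cat class, Table 5
PPV = TP / (TP + FP)
print('{:.2f}%'.format(PPV * 100))  # 95.92%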

 

1.3.3 Sensitivity / Recall

Recall: among all results whose actual value is Positive, the proportion the model predicts correctly. The formula is:

TPR = \tfrac{TP}{TP+FN}

In Table 2, the recall for the cat class is TPR_{cat} = \tfrac{45}{45+5} = 90.00%; in Table 5, the recall for the cat class in the three-class task is TPR_{cat} = \tfrac{47}{47+4} \approx 92.16%.

Recall is about the samples (the actual values): it asks the classifier to be exhaustive, valuing coverage. Put simply, a false alarm ("wrongful conviction") is cheap while a miss ("fish slipping through the net") is expensive. Suppose we predict whether an earthquake will occur (True) or not (False). If the system predicts an earthquake today and issues a warning but none occurs, people merely waste some time evacuating; but if an earthquake occurs without being predicted, lives and property suffer severe losses.

In information retrieval, recall is also called the recall rate (查全率).
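And the corresponding check for recall:

TP, FN = 47, 4  # the cat class, Table 5
TPR = TP / (TP + FN)
print('{:.2f}%'.format(TPR * 100))  # 92.16%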

 

1.3.4 Specificity

Specificity: among all results whose actual value is Negative, the proportion the model predicts correctly. The formula is:

TNR = \tfrac{TN}{TN+FP}
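The original worked example stops at the formula, but plugging in the cat figures from Table 5 (TN = 99, FP = 2) gives:

TN, FP = 99, 2  # the cat class, Table 5
TNR = TN / (TN + FP)
print('{:.2f}%'.format(TNR * 100))  # 98.02%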

 

1.4 Third-Level Metrics

 

1.4.1 F-measure

The F-measure is the weighted harmonic mean of Precision and Recall, computed as:

F = \tfrac{\left (\alpha ^{2}+1 \right )\cdot Precision\cdot Recall}{\alpha ^{2}\cdot Precision+Recall}

 

1.4.2 F1-measure

When α = 1, F1\: Score = \tfrac{2\cdot Precision\cdot Recall}{Precision+Recall}.

The F1 Score lies in [0, 1]: the closer it is to 0, the worse the model's output; the closer it is to 1, the better.
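A small sketch of the general formula, using the cat precision and recall from Table 5; note that F1 reduces to 2·TP / (2·TP + FP + FN), which here is exactly 94/100:

def f_measure(precision, recall, alpha=1.0):
    # Weighted harmonic mean of precision and recall; alpha = 1 gives the F1 Score
    return (alpha ** 2 + 1) * precision * recall / (alpha ** 2 * precision + recall)

PPV, TPR = 47 / 49, 47 / 51  # the cat class, Table 5
print('{:.2f}%'.format(f_measure(PPV, TPR) * 100))  # 94.00%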

 

2 Code

The code analyzes the three-class results of the ResNet18 and MobileNetV2 models: it tallies and prints the confusion matrix of each model, computes and prints the second- and third-level metrics, and finally saves the confusion matrices and metrics to an Excel file for further analysis.

2.1 Source Code

import os
import shutil

import numpy as np
import xlwt


# Build an xlwt cell style (font + thin borders)
def set_style(name, height, bold=False):
    style = xlwt.XFStyle()

    font = xlwt.Font()
    font.name = name
    font.bold = bold
    font.color_index = 4
    font.height = height
    style.font = font

    borders = xlwt.Borders()
    borders.left = 1
    borders.right = 1
    borders.top = 1
    borders.bottom = 1
    borders.bottom_colour = 0x3A
    style.borders = borders

    return style


# Write the confusion matrix, metric table and per-class TP/FN/FP/TN
# blocks of every model to an Excel file (one sheet per model)
def write_excel(info_dict, res_name="./res.xls", cls_dict=None):
    f = xlwt.Workbook()
    for sheet_name, val_dict in info_dict.items():
        sheet = f.add_sheet(sheet_name, cell_overwrite_ok=True)

        # Confusion matrix block
        # row0   = ["Confusion Matrix", "class 0", ..., "class N", "Pass Rate (%)"]
        # colum0 = ["class 0", ..., "class N", "Model"]
        row0 = ["Confusion Matrix"]
        colum0 = []
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
            colum0.append(cls_name)
        row0.append("Pass Rate (%)")
        colum0.append("Model")

        # first row
        for i in range(0, len(row0)):
            sheet.write(0, i, row0[i], set_style('Times New Roman', 220, True))
        # first column
        for i in range(0, len(colum0)):
            sheet.write(i + 1, 0, colum0[i], set_style('Times New Roman', 220, True))

        confusion_matrix = val_dict['confusion_matrix']
        for row in range(confusion_matrix.shape[0]):
            for col in range(confusion_matrix.shape[1]):
                sheet.write(row + 1, col + 1, int(confusion_matrix[row][col]),
                            set_style('Times New Roman', 220, False))

        # Accuracy of each class
        for row in range(confusion_matrix.shape[0]):
            if sum(confusion_matrix[row]) == 0:
                ACC = -1
            else:
                ACC = round(confusion_matrix[row][row] / sum(confusion_matrix[row]) * 100, 2)
            sheet.write(row + 1, confusion_matrix.shape[1] + 1, ACC,
                        set_style('Times New Roman', 220, False))

        # Accuracy of the model
        num_correct = 0
        for cls_index in range(confusion_matrix.shape[0]):
            num_correct += confusion_matrix[cls_index][cls_index]
        ACC_model = round(num_correct / sum(sum(confusion_matrix)) * 100, 2)
        sheet.write(confusion_matrix.shape[0] + 1, confusion_matrix.shape[1] + 1,
                    ACC_model, set_style('Times New Roman', 220, False))
        for i in range(1, confusion_matrix.shape[1] + 1):
            sheet.write(confusion_matrix.shape[0] + 1, i, '',
                        set_style('Times New Roman', 220, False))

        # Metrics block - Accuracy (ACC), Precision (PPV), Sensitivity (Recall, TPR),
        # Specificity (TNR), F1-Score
        sep = 2
        first_row = confusion_matrix.shape[0] + 2 + sep
        # row0   = ["Index (%)", "class 0", ..., "class N"]
        # colum0 = ["Accuracy", ..., "F1-Score"]
        row0 = ["Index (%)"]
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
        colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        for i in range(0, len(row0)):
            sheet.write(first_row, i, row0[i], set_style('Times New Roman', 220, True))
        for i in range(0, len(colum0)):
            sheet.write(i + 1 + first_row, 0, colum0[i], set_style('Times New Roman', 220, True))
        index_list = val_dict['index']
        for row in range(len(index_list)):
            for col in range(len(index_list[row])):
                sheet.write(col + 1 + first_row, row + 1,
                            round(index_list[row][col] * 100, 2),
                            set_style('Times New Roman', 220, False))

        # Per-class TP/FN/FP/TN block (one 2x2 table per class, to the right)
        sep = 1
        first_col = confusion_matrix.shape[1] + 2 + sep
        row0 = ["Positive", "Negative"]
        colum0 = ["Positive", "Negative"]
        for cls_index in range(len(cls_dict.keys())):
            block_row = cls_index * len(colum0) + cls_index * 2
            sheet.write(block_row, first_col, cls_dict[cls_index],
                        set_style('Times New Roman', 220, True))
            for i in range(0, len(row0)):
                sheet.write(block_row, i + 1 + first_col, row0[i],
                            set_style('Times New Roman', 220, False))
            for i in range(0, len(colum0)):
                sheet.write(block_row + i + 1, first_col, colum0[i],
                            set_style('Times New Roman', 220, False))
            # values stored as [TP, FN, FP, TN]
            value = val_dict[cls_dict[cls_index]]
            sheet.write(block_row + 1, first_col + 1, int(value[0]), set_style('Times New Roman', 220, False))
            sheet.write(block_row + 1, first_col + 2, int(value[1]), set_style('Times New Roman', 220, False))
            sheet.write(block_row + 2, first_col + 1, int(value[2]), set_style('Times New Roman', 220, False))
            sheet.write(block_row + 2, first_col + 2, int(value[3]), set_style('Times New Roman', 220, False))

    f.save(res_name)


# Pretty-print an N x N confusion matrix with per-class accuracy
def print_confusion_matrix(confusion_matrix, model_name='', cls_dict=''):
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' Confusion Matrix -------------')
    print('row: target, col: predicted')
    print('+' + '-------+' * (len(cls_dict.keys()) + 2))
    print('|' + '\t', end='')
    for predict_index in range(num_cls):
        print('|' + cls_dict[predict_index].rjust(6), end='\t')
    print('|' + 'Acc'.rjust(6), end=' |\n')
    print('+' + '-------+' * (len(cls_dict.keys()) + 2))
    for target_index in range(num_cls):
        print('|' + cls_dict[target_index].rjust(6), end='\t')
        for predict_index in range(num_cls):
            print('|' + str(confusion_matrix[target_index][predict_index]).rjust(6), end='\t')
        print('|' + "{:.2f}%".format(
            confusion_matrix[target_index][target_index] / sum(confusion_matrix[target_index]) * 100
        ).rjust(6), end=' |\n')
        print('+' + '-------+' * (len(cls_dict.keys()) + 2))


# Overall accuracy = trace of the confusion matrix / number of samples
def cal_accuracy_rate(confusion_matrix, model_name='', cls_dict=''):
    num_cls = confusion_matrix.shape[0]
    num_total = np.sum(confusion_matrix)
    num_accuracy = 0
    for i in range(num_cls):
        num_accuracy += confusion_matrix[i][i]
    accuracy_rate = num_accuracy / num_total
    print('')
    print('------------- ', model_name, ' Accuracy Rate -------------')
    print('Number of correct prediction is ', num_accuracy)
    print('Number of test data is ', num_total)
    print('The accuracy rate is {:.2f}'.format(accuracy_rate * 100) + '%')
    print('-------------------------------------------------------')


# Per-class TP/TN/FP/FN and the derived ACC, PPV, TPR, TNR and F1-Score
def cal_other_index(confusion_matrix, res_dict, model_name='', cls_dict=''):
    res_dict['index'] = []
    index_dict = {}
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' ACC, PPV, TPR, TNR, F1-Score -------------')
    for cls_index in range(num_cls):
        index_dict[cls_dict[cls_index]] = {}
        # TP, TN, FP, FN of the current class
        TP = confusion_matrix[cls_index][cls_index]
        TN, FP, FN = 0, 0, 0
        for target_index in range(num_cls):
            for predict_index in range(num_cls):
                if target_index == cls_index and predict_index != cls_index:
                    FN += confusion_matrix[target_index][predict_index]
                elif predict_index == cls_index and target_index != cls_index:
                    FP += confusion_matrix[target_index][predict_index]
                elif target_index != cls_index and predict_index != cls_index:
                    TN += confusion_matrix[target_index][predict_index]
        index_dict[cls_dict[cls_index]]['TP'] = TP
        index_dict[cls_dict[cls_index]]['TN'] = TN
        index_dict[cls_dict[cls_index]]['FP'] = FP
        index_dict[cls_dict[cls_index]]['FN'] = FN
        # Accuracy
        ACC = (TP + TN) / (TP + TN + FP + FN)
        # Precision
        PPV = TP / (TP + FP)
        # Sensitivity (Recall)
        TPR = TP / (TP + FN)
        # Specificity
        TNR = TN / (TN + FP)
        # F1-Score
        F1_Score = 2 * PPV * TPR / (PPV + TPR)
        # stored as [TP, FN, FP, TN]
        res_dict[cls_dict[cls_index]] = [TP, FN, FP, TN]
        res_dict['index'].append([ACC, PPV, TPR, TNR, F1_Score])
        # print the 2x2 table and metrics of this class
        print('cls: ' + cls_dict[cls_index], '\trow: target, col: predicted')
        print('+' + '-------+' * 3)
        print('|' + '\t|', 'Pos'.rjust(6) + '|', 'Neg'.rjust(6) + '|')
        print('+' + '-------+' * 3)
        print('|' + 'Pos'.rjust(6) + ' |' + str(TP).rjust(6) + ' |' + str(FN).rjust(6) + ' |')
        print('+' + '-------+' * 3)
        print('|' + 'Neg'.rjust(6) + ' |' + str(FP).rjust(6) + ' |' + str(TN).rjust(6) + ' |')
        print('+' + '-------+' * 3)
        print('Accuracy = {:.2f}'.format(ACC * 100) + '%')
        print('Precision = {:.2f}'.format(PPV * 100) + '%')
        print('Sensitivity (Recall) = {:.2f}'.format(TPR * 100) + '%')
        print('Specificity = {:.2f}'.format(TNR * 100) + '%')
        print('F1-Score = {:.2f}'.format(F1_Score * 100) + '%')
        print('----------' * 2)


if __name__ == '__main__':
    root_path = './results'
    save_path = './results/images'
    version_suffix = 'v4_3cls'
    cls_suffix = "blog"
    # result file name -> model name
    models_dict = {
        'resnet18_' + version_suffix + "_" + cls_suffix + '.txt': 'resnet18_' + version_suffix,
        'mobilenetv2_' + version_suffix + "_" + cls_suffix + '.txt': 'mobilenetv2_' + version_suffix,
    }
    xls_path = './results/excel/res_' + version_suffix + "_" + cls_suffix + '.xls'
    cls_dict = {0: 'cat', 1: 'dog', 2: 'pig'}
    num_cls = len(cls_dict.keys())
    res_dict = {}

    for res_name, model_name in models_dict.items():
        res_dict[model_name] = {}
        confusion_matrix = np.zeros((num_cls, num_cls), dtype=int)
        res_path = os.path.join(root_path, res_name)
        with open(res_path) as f_src:
            # each line: "<image path> <target label> <predicted label>"
            for line in f_src.readlines():
                line_split = line.split('\n')[0].split(' ')
                target = int(line_split[1])
                predict = int(line_split[2])
                confusion_matrix[target][predict] += 1
                # Optionally copy misclassified images somewhere for inspection:
                # if target != predict:
                #     new_path = os.path.join(save_path, model_name, cls_suffix,
                #                             cls_dict[predict], line_split[0].split('/')[-1])
                #     shutil.copyfile(line_split[0], new_path)
        res_dict[model_name]['confusion_matrix'] = confusion_matrix
        print_confusion_matrix(confusion_matrix, model_name, cls_dict)
        cal_accuracy_rate(confusion_matrix, model_name, cls_dict)
        cal_other_index(confusion_matrix, res_dict[model_name], model_name, cls_dict)

    write_excel(res_dict, xls_path, cls_dict)

2.2 Examples

2.2.1 Console output

For each model, the script prints the N×N confusion matrix with per-class accuracy, the overall accuracy rate, and a 2×2 Positive/Negative table per class together with its Accuracy, Precision, Sensitivity (Recall), Specificity, and F1-Score.

2.2.2 Excel output

The same results are written to the .xls file, one worksheet per model: the confusion matrix with pass rates, the metrics table, and the per-class TP/FN/FP/TN blocks.

 

3 Summary

In classification tasks, the most commonly used analysis tools are the confusion matrix, accuracy, precision, and recall. Different tasks weight these differently, so we should choose, according to the actual situation, the methods that improve the metric that matters most.

 

 

 

 

 

 
