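A while ago I worked on a classification project that used ResNet18 and MobileNetV2 to classify the data; the former is mainly for the GPU side and the latter mainly for the CPU side. I evaluated the models chiefly by computing the confusion matrix together with accuracy, recall, and F-score. This post explains each of these metrics in detail.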
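First, the abbreviations: T stands for True, P for Positive, F for False, and N for Negative. Therefore:
TP: the actual value is Positive and the predicted value is Positive, i.e. a true positive;
FP: the actual value is Negative and the predicted value is Positive, i.e. a false positive;
FN: the actual value is Positive and the predicted value is Negative, i.e. a false negative;
TN: the actual value is Negative and the predicted value is Negative, i.e. a true negative.
In statistics, FP is called a Type I error and FN is called a Type II error.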
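The four definitions above can be sketched as a small counting routine. This is only an illustration: the function name `count_outcomes` and the label lists are hypothetical and do not come from the project described in this post.

```python
def count_outcomes(y_true, y_pred, positive=1):
    """Return (TP, FP, FN, TN) for a binary task."""
    tp = fp = fn = tn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == positive:
            if truth == positive:
                tp += 1  # predicted Positive, actually Positive
            else:
                fp += 1  # predicted Positive, actually Negative (Type I error)
        else:
            if truth == positive:
                fn += 1  # predicted Negative, actually Positive (Type II error)
            else:
                tn += 1  # predicted Negative, actually Negative
    return tp, fp, fn, tn


# Hypothetical labels: 1 = Positive, 0 = Negative
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 0, 0, 1]
tp, fp, fn, tn = count_outcomes(y_true, y_pred)
```

Every sample falls into exactly one of the four buckets, so the four counts always add up to the number of samples.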
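Arranging TP, FP, FN, and TN in a table gives the confusion matrix.
In binary classification, each sample has only two possible predictions: positive or negative.
Table 1. Confusion matrix for binary classification (rows: actual value, columns: predicted value)
| | Predicted Positive | Predicted Negative |
|---|---|---|
| Actual Positive | TP | FN |
| Actual Negative | FP | TN |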
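For example, suppose we want to classify pictures of cats and dogs. We can treat cat as the positive class and dog as the negative class, so there are only two possible predictions: cat, or not cat (dog). Suppose there are 50 cat pictures and 53 dog pictures, of which 45 and 47 respectively are predicted correctly. The confusion matrix is then as shown in Table 2.
Table 2. Confusion matrix for the cat/dog example (rows: actual value, columns: predicted value)
| | Predicted cat | Predicted dog |
|---|---|---|
| Actual cat | 45 | 5 |
| Actual dog | 6 | 47 |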
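Classification tasks are often not simple binary problems; they may have three or more classes, and the confusion matrix then looks different from the binary one. The confusion matrix for N classes is shown in Table 3.
Table 3. Confusion matrix for N-class classification (rows: actual value, columns: predicted value)
| | Predicted class 1 | Predicted class 2 | ··· | Predicted class N |
|---|---|---|---|---|
| Actual class 1 | | | | |
| Actual class 2 | | | | |
| ··· | | | | |
| Actual class N | | | | |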
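For example, suppose we want to classify pictures of cats, dogs, and pigs, with 51, 52, and 49 pictures respectively. Of the cat pictures, 47, 1, and 3 are predicted as cat, dog, and pig; of the dog pictures, 1, 49, and 2 are predicted as cat, dog, and pig; of the pig pictures, 1, 0, and 48 are predicted as cat, dog, and pig. The confusion matrix is then as shown in Table 4.
Table 4. Confusion matrix for the cat/dog/pig example (rows: actual value, columns: predicted value)
| | Predicted cat | Predicted dog | Predicted pig |
|---|---|---|---|
| Actual cat | 47 | 1 | 3 |
| Actual dog | 1 | 49 | 2 |
| Actual pig | 1 | 0 | 48 |
The three cats that were mistaken for pigs were, in all likelihood, orange tabbies.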
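As a minimal sketch of how an N-class confusion matrix such as Table 4 is accumulated from (actual, predicted) label pairs, consider the snippet below. The helper name `build_confusion_matrix` and the class indices 0 = cat, 1 = dog, 2 = pig are assumptions made for illustration; the label lists are synthesized from the counts in Table 4 rather than taken from real model output.

```python
def build_confusion_matrix(targets, predictions, num_classes):
    """matrix[actual][predicted] counts how often each pair occurs."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for target, predicted in zip(targets, predictions):
        matrix[target][predicted] += 1
    return matrix


# (actual, predicted) -> number of pictures, copied from Table 4
pair_counts = {
    (0, 0): 47, (0, 1): 1, (0, 2): 3,
    (1, 0): 1, (1, 1): 49, (1, 2): 2,
    (2, 0): 1, (2, 1): 0, (2, 2): 48,
}

# Expand the counts into one label per picture
targets, predictions = [], []
for (target, predicted), count in pair_counts.items():
    targets += [target] * count
    predictions += [predicted] * count

table4 = build_confusion_matrix(targets, predictions, num_classes=3)
```

Each row sums to the number of pictures of that actual class (51, 52, and 49), and the diagonal holds the correctly classified pictures.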
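Each class can also be analyzed as a binary problem, i.e. the prediction is either this class or not this class, with its own confusion matrix. For the cat class, for example, predictions of dog and pig can be lumped together as "not cat", which gives the confusion matrix for cat shown in Table 5.
Table 5. Confusion matrix for the cat class in the three-class example (rows: actual value, columns: predicted value)
| | Predicted cat | Predicted not cat |
|---|---|---|
| Actual cat | 47 | 4 |
| Actual not cat | 2 | 99 |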
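The collapse from the three-class matrix to a per-class binary matrix can be sketched as follows. The function name `one_vs_rest` is hypothetical; the matrix is the one from Table 4, indexed as [actual][predicted] with 0 = cat, 1 = dog, 2 = pig.

```python
def one_vs_rest(matrix, k):
    """Collapse an N-class confusion matrix into (TP, FN, FP, TN) for class k."""
    num_classes = len(matrix)
    tp = matrix[k][k]
    fn = sum(matrix[k]) - tp  # rest of row k: actual k, predicted something else
    fp = sum(matrix[i][k] for i in range(num_classes)) - tp  # rest of column k
    total = sum(sum(row) for row in matrix)
    tn = total - tp - fn - fp  # everything that touches neither row k nor column k
    return tp, fn, fp, tn


table4 = [[47, 1, 3], [1, 49, 2], [1, 0, 48]]
cat = one_vs_rest(table4, 0)
dog = one_vs_rest(table4, 1)
pig = one_vs_rest(table4, 2)
```

For cat this yields TP = 47, FN = 4, FP = 2, TN = 99, matching Table 5.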
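The main secondary metrics are accuracy, precision, recall, and specificity.
Accuracy: the share of all observations that are classified correctly. Accuracy describes the model as a whole. It is computed as follows:
(1) Multi-class models
ACC = number of correctly classified samples / total number of observations
For Table 4, for example, the classification accuracy is ACC = (47 + 49 + 48) / (51 + 52 + 49) = 144 / 152 ≈ 94.74%.
In a multi-class model, the accuracy for class k is ACC_k = (TP_k + TN_k) / (TP_k + TN_k + FP_k + FN_k), where the four counts come from treating class k as Positive and all other classes as Negative.
For Table 5, for example, the accuracy for cat in the three-class task is ACC = (47 + 99) / (47 + 99 + 2 + 4) = 146 / 152 ≈ 96.05%.
(2) Binary models
ACC = (TP + TN) / (TP + TN + FP + FN)
For Table 2, for example, the classification accuracy is ACC = (45 + 47) / (45 + 47 + 6 + 5) = 92 / 103 ≈ 89.32%.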
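The accuracy figures for the cat/dog and cat/dog/pig examples can be checked with a short sketch. The helper names `overall_accuracy` and `binary_accuracy` are illustrative assumptions; the matrices are indexed as [actual][predicted].

```python
def overall_accuracy(matrix):
    """Correct predictions (the diagonal) divided by all observations."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def binary_accuracy(tp, tn, fp, fn):
    """ACC = (TP + TN) / (TP + TN + FP + FN)."""
    return (tp + tn) / (tp + tn + fp + fn)


table2 = [[45, 5], [6, 47]]                     # cat/dog example
table4 = [[47, 1, 3], [1, 49, 2], [1, 0, 48]]   # cat/dog/pig example

acc_table2 = overall_accuracy(table2)                       # 92 / 103
acc_table4 = overall_accuracy(table4)                       # 144 / 152
acc_cat = binary_accuracy(tp=47, tn=99, fp=2, fn=4)         # 146 / 152
```

Note that the per-class accuracy for cat (146/152) is higher than the overall three-class accuracy (144/152), because mistakes between dog and pig count as true negatives from the cat point of view.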
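Precision: of all results that the model predicts as Positive, the share that is predicted correctly. It is computed as:
PPV = TP / (TP + FP)
In Table 2, the precision for cat is 45 / (45 + 6) = 45 / 51 ≈ 88.24%; in Table 5, the precision for cat in the three-class task is 47 / (47 + 2) = 47 / 49 ≈ 95.92%.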
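Precision can also be read directly off a multi-class matrix: for class k it is the diagonal entry divided by the sum of column k (everything predicted as k). The sketch below assumes the [actual][predicted] layout used in the tables; the function name `precision_for_class` and the choice to return `None` when a class is never predicted are illustrative decisions, not part of the original code.

```python
def precision_for_class(matrix, k):
    """PPV for class k: TP divided by everything predicted as k (column k)."""
    predicted_as_k = sum(row[k] for row in matrix)
    if predicted_as_k == 0:
        return None  # class k was never predicted, so precision is undefined
    return matrix[k][k] / predicted_as_k


table4 = [[47, 1, 3], [1, 49, 2], [1, 0, 48]]   # 0 = cat, 1 = dog, 2 = pig
precisions = [precision_for_class(table4, k) for k in range(3)]
# cat: 47/49, dog: 49/50, pig: 48/53
```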
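Precision is defined relative to the predictions. Put simply, it matters when a "wrongful conviction" (false positive) is costly and "one that slipped through the net" (false negative) is cheap. For example, suppose we decide whether an email is spam: True if it is, False otherwise. If a spam email is mistaken for a normal one, we may only waste a few seconds opening it; but if an important email is thrown into the spam folder, we may miss important information. In this case FP should be as small as possible, and with TP unchanged, PPV should be as large as possible.
In information retrieval, precision is also referred to as 查准率 (roughly, the "rate of correct retrieval").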
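Recall: of all results whose actual value is Positive, the share that the model predicts correctly. It is computed as:
TPR = TP / (TP + FN)
In Table 2, the recall for cat is 45 / (45 + 5) = 45 / 50 = 90.00%; in Table 5, the recall for cat in the three-class task is 47 / (47 + 4) = 47 / 51 ≈ 92.16%.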
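In a multi-class matrix, recall for class k is the diagonal entry divided by the sum of row k (everything that actually is k). A minimal sketch, again assuming the [actual][predicted] layout; `recall_for_class` is a hypothetical helper name.

```python
def recall_for_class(matrix, k):
    """TPR for class k: TP divided by everything that actually is k (row k)."""
    actual_k = sum(matrix[k])
    if actual_k == 0:
        return None  # no samples of class k, so recall is undefined
    return matrix[k][k] / actual_k


table4 = [[47, 1, 3], [1, 49, 2], [1, 0, 48]]   # 0 = cat, 1 = dog, 2 = pig
recalls = [recall_for_class(table4, k) for k in range(3)]
# cat: 47/51, dog: 49/52, pig: 48/49
```

These row-wise ratios are the same quantity that the script in this post prints in the "Acc" column next to each row of the confusion matrix.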
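Recall is defined relative to the samples (the actual values). It asks the classification result to be broad and complete, with an emphasis on quantity. Put simply, it matters when a "wrongful conviction" (false positive) is cheap and "one that slipped through the net" (false negative) is costly. For example, suppose we predict whether an earthquake will occur at a given time: True if it will, False otherwise. If the system predicts an earthquake today and issues an early warning, then even if no earthquake occurs, people only lose some time taking shelter; but if an earthquake does occur and was not predicted, lives and property suffer severe losses.
In information retrieval, recall is also referred to as 查全率 (roughly, the "rate of complete retrieval").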
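Specificity: of all results whose actual value is Negative, the share that the model predicts correctly. It is computed as:
TNR = TN / (TN + FP)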
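A short sketch of specificity on the two running examples. The helper names are hypothetical. It also illustrates a handy identity for the binary case: the specificity of the positive class equals the recall of the negative class, since both are "actual Negatives predicted as Negative, divided by all actual Negatives".

```python
def specificity(tn, fp):
    """TNR = TN / (TN + FP)."""
    return tn / (tn + fp)


def recall(tp, fn):
    """TPR = TP / (TP + FN)."""
    return tp / (tp + fn)


# Table 2 with cat as Positive: TP = 45, FN = 5, FP = 6, TN = 47
cat_specificity = specificity(tn=47, fp=6)      # 47 / 53

# Same table with dog as Positive: TP = 47, FN = 6
dog_recall = recall(tp=47, fn=6)                # 47 / 53

# Table 5 (cat vs. not cat in the three-class task): TN = 99, FP = 2
cat_specificity_3cls = specificity(tn=99, fp=2)  # 99 / 101
```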
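The F-measure is the weighted harmonic mean of Precision and Recall. It is computed as:
F_α = (1 + α²) · Precision · Recall / (α² · Precision + Recall)
When α = 1,
F1 = 2 · Precision · Recall / (Precision + Recall)
The F1 score ranges over [0, 1]. The closer it is to 0, the worse the model's output; the closer it is to 1, the better the model's output.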
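The sketch below evaluates the weighted F-measure for the cat class of Table 5, using the standard definition F_α = (1 + α²)·P·R / (α²·P + R). The function name `f_measure` and the convention of returning 0 when both precision and recall are 0 are assumptions made for this illustration.

```python
def f_measure(precision, recall, alpha=1.0):
    """Weighted harmonic mean of precision and recall; alpha > 1 favors recall."""
    if precision == 0 and recall == 0:
        return 0.0  # convention chosen here to avoid dividing by zero
    a2 = alpha * alpha
    return (1 + a2) * precision * recall / (a2 * precision + recall)


# Cat class in Table 5: TP = 47, FP = 2, FN = 4
precision = 47 / (47 + 2)
recall = 47 / (47 + 4)

f1 = f_measure(precision, recall)             # equals 2*TP / (2*TP + FP + FN) = 94/100
f2 = f_measure(precision, recall, alpha=2.0)  # weights recall more heavily
```

Because recall is lower than precision for this class, giving recall more weight (α = 2) pulls the score below F1.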
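The code below analyzes the three-class results of the two models, ResNet18 and MobileNetV2. It builds and prints the confusion matrix, computes and prints the secondary and tertiary metrics, and finally saves the confusion matrix and the metrics to an Excel file for later analysis.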
import os
import shutil  # only used by the optional "copy misclassified images" block below

import numpy as np
import xlwt


# Build a cell style (font and thin borders) for the Excel sheet
def set_style(name, height, bold=False):
    style = xlwt.XFStyle()
    font = xlwt.Font()
    font.name = name
    font.bold = bold
    font.color_index = 4
    font.height = height
    style.font = font

    borders = xlwt.Borders()
    borders.left = 1
    borders.right = 1
    borders.top = 1
    borders.bottom = 1
    borders.bottom_colour = 0x3A
    style.borders = borders
    return style


# Write the results of every model to one Excel workbook, one sheet per model
def write_excel(info_dict, res_name="./res.xls", cls_dict=""):
    header_style = set_style('Times New Roman', 220, True)
    cell_style = set_style('Times New Roman', 220, False)

    f = xlwt.Workbook()
    for sheet_name, val_dict in info_dict.items():
        sheet = f.add_sheet(sheet_name, cell_overwrite_ok=True)

        # Confusion Matrix
        # row0 = ["Confusion Matrix", "class 0", "class 1", ..., "class N", "Pass Rate (%)"]
        # colum0 = ["class 0", "class 1", ..., "class N", "Model"]
        row0 = ["Confusion Matrix"]
        colum0 = []
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
            colum0.append(cls_name)
        row0.append("Pass Rate (%)")
        colum0.append("Model")
        # first row
        for i in range(0, len(row0)):
            sheet.write(0, i, row0[i], header_style)
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1, 0, colum0[i], header_style)

        confusion_matrix = val_dict['confusion_matrix']
        for row in range(confusion_matrix.shape[0]):
            for col in range(confusion_matrix.shape[1]):
                sheet.write(row+1, col+1, int(confusion_matrix[row][col]), cell_style)

        # Pass rate of each class (correct / number of samples in that row); -1 if the row is empty
        for row in range(confusion_matrix.shape[0]):
            if sum(confusion_matrix[row]) == 0:
                ACC = -1
            else:
                ACC = round(confusion_matrix[row][row]/sum(confusion_matrix[row])*100, 2)
            sheet.write(row+1, confusion_matrix.shape[1]+1, ACC, cell_style)

        # Accuracy of the model
        num_correct = 0
        for cls_index in range(confusion_matrix.shape[0]):
            num_correct += confusion_matrix[cls_index][cls_index]
        ACC_model = round(num_correct/sum(sum(confusion_matrix))*100, 2)
        sheet.write(confusion_matrix.shape[0]+1, confusion_matrix.shape[1]+1, ACC_model, cell_style)
        for i in range(1, confusion_matrix.shape[1]+1):
            sheet.write(confusion_matrix.shape[0]+1, i, '', cell_style)

        sep = 2
        # Index - Accuracy (ACC), Precision (PPV), Sensitivity (Recall, TPR), Specificity (TNR), F1-Score
        # row0 = ["Index (%)", "class 0", "class 1", ..., "class N"]
        # colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        first_row = confusion_matrix.shape[0] + 2 + sep
        row0 = ["Index (%)"]
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
        colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        # first row
        for i in range(0, len(row0)):
            sheet.write(first_row, i, row0[i], header_style)
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1+first_row, 0, colum0[i], header_style)
        index_list = val_dict['index']
        for row in range(len(index_list)):
            for col in range(len(index_list[row])):
                sheet.write(col+1+first_row, row+1, round(index_list[row][col]*100, 2), cell_style)

        sep = 1
        # TP, FN, FP, TN of each class, as one small 2x2 table per class
        first_col = confusion_matrix.shape[1] + 2 + sep
        row0 = ["Positive", "Negative"]
        colum0 = ["Positive", "Negative"]
        for cls_index in range(len(cls_dict.keys())):
            base_row = cls_index*len(colum0) + cls_index*2
            sheet.write(base_row, first_col, cls_dict[cls_index], header_style)
            # first row
            for i in range(0, len(row0)):
                sheet.write(base_row, i+1+first_col, row0[i], cell_style)
            # first col
            for i in range(0, len(colum0)):
                sheet.write(base_row+i+1, first_col, colum0[i], cell_style)
            # value - TP, FN, FP, TN
            value = val_dict[cls_dict[cls_index]]
            sheet.write(base_row+1, first_col+1, int(value[0]), cell_style)
            sheet.write(base_row+1, first_col+2, int(value[1]), cell_style)
            sheet.write(base_row+2, first_col+1, int(value[2]), cell_style)
            sheet.write(base_row+2, first_col+2, int(value[3]), cell_style)

    f.save(res_name)


def print_confusion_matrix(confusion_matrix, model_name='', cls_dict=''):
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' Confusion Matrix -------------')
    print('row: target, col: predicted')
    print('+'+'-------+'*(len(cls_dict.keys())+2))
    print('|'+'\t', end='')
    for predict_index in range(num_cls):
        print('|'+cls_dict[predict_index].rjust(6), end='\t')
    print('|'+'Acc'.rjust(6), end=' |\n')
    print('+'+'-------+'*(len(cls_dict.keys())+2))
    for target_index in range(num_cls):
        print('|'+cls_dict[target_index].rjust(6), end='\t')
        for predict_index in range(num_cls):
            print('|'+str(confusion_matrix[target_index][predict_index]).rjust(6), end='\t')
        print('|'+"{:.2f}%".format(confusion_matrix[target_index][target_index]/sum(confusion_matrix[target_index])*100).rjust(6), end=' |\n')
        print('+'+'-------+'*(len(cls_dict.keys())+2))


def cal_accuracy_rate(confusion_matrix, model_name='', cls_dict=''):
    # Take the class count from the matrix instead of relying on a global variable
    num_cls = confusion_matrix.shape[0]
    num_total = np.sum(confusion_matrix)
    # accuracy rate
    num_accuracy = 0
    for i in range(num_cls):
        num_accuracy += confusion_matrix[i][i]
    accuracy_rate = num_accuracy / num_total
    print('')
    print('------------- ', model_name, ' Accuracy Rate -------------')
    print('Number of correct prediction is ', num_accuracy)
    print('Number of test data is ', num_total)
    print('The accuracy rate is {:.2f}'.format(accuracy_rate*100)+'%')
    print('-------------------------------------------------------')


def cal_other_index(confusion_matrix, res_dict=None, model_name='', cls_dict=''):
    # Avoid a mutable default argument; callers normally pass their own dict
    if res_dict is None:
        res_dict = {}
    res_dict['index'] = []
    index_dict = {}
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' ACC, PPV, TPR, TNR, F1-Score -------------')
    for cls_index in range(num_cls):
        index_dict[cls_dict[cls_index]] = {}
        # TP, TN, FP, FN (class cls_index is Positive, all other classes are Negative)
        TP = confusion_matrix[cls_index][cls_index]
        TN, FP, FN = 0, 0, 0
        for target_index in range(num_cls):
            for predict_index in range(num_cls):
                if target_index == cls_index and predict_index != cls_index:
                    FN += confusion_matrix[target_index][predict_index]
                elif predict_index == cls_index and target_index != cls_index:
                    FP += confusion_matrix[target_index][predict_index]
                elif target_index != cls_index and predict_index != cls_index:
                    TN += confusion_matrix[target_index][predict_index]
        index_dict[cls_dict[cls_index]]['TP'] = TP
        index_dict[cls_dict[cls_index]]['TN'] = TN
        index_dict[cls_dict[cls_index]]['FP'] = FP
        index_dict[cls_dict[cls_index]]['FN'] = FN

        # Accuracy
        ACC = (TP + TN) / (TP + TN + FP + FN)
        # Precision
        PPV = TP / (TP + FP)
        # Sensitivity (Recall)
        TPR = TP / (TP + FN)
        # Specificity
        TNR = TN / (TN + FP)
        # F1-Score
        F1_Score = 2 * PPV * TPR / (PPV + TPR)

        # TP, FN, FP, TN
        res_dict[cls_dict[cls_index]] = [TP, FN, FP, TN]
        res_dict['index'].append([ACC, PPV, TPR, TNR, F1_Score])

        # print results
        print('cls: '+cls_dict[cls_index], '\trow: target, col: predicted')
        print('+'+'-------+'*3)
        print('|'+'\t|', 'Pos'.rjust(6)+'|', 'Neg'.rjust(6)+'|')
        print('+'+'-------+'*3)
        print('|'+'Pos'.rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['TP']).rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['FN']).rjust(6)+' |')
        print('+'+'-------+'*3)
        print('|'+'Neg'.rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['FP']).rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['TN']).rjust(6)+' |')
        print('+'+'-------+'*3)
        print('Accuracy = {:.2f}'.format(ACC*100)+'%')
        print('Precision = {:.2f}'.format(PPV*100)+'%')
        print('Sensitivity (Recall) = {:.2f}'.format(TPR*100)+'%')
        print('Specificity = {:.2f}'.format(TNR*100)+'%')
        print('F1-Score = {:.2f}'.format(F1_Score*100)+'%')
        print('----------'*2)


if __name__ == '__main__':
    root_path = './results'
    save_path = './results/images'
    version_suffix = 'v4_3cls'
    cls_suffix = "blog"
    # result file name -> model name (also used as the Excel sheet name)
    models_dict = {'resnet18_'+version_suffix+"_"+cls_suffix+'.txt': 'resnet18_'+version_suffix,
                   'mobilenetv2_'+version_suffix+"_"+cls_suffix+'.txt': 'mobilenetv2_'+version_suffix}
    xls_path = './results/excel/res_'+version_suffix+"_"+cls_suffix+'.xls'
    cls_dict = {0: 'cat', 1: 'dog', 2: 'pig'}
    num_cls = len(cls_dict.keys())

    res_dict = {}
    for res_name, model_name in models_dict.items():
        res_dict[model_name] = {}
        # initialize matrix
        confusion_matrix = np.zeros((num_cls, num_cls), dtype=int)
        res_path = os.path.join(root_path, res_name)
        with open(res_path) as f_src:
            lines = f_src.readlines()
        # each line: "<image path> <target index> <predicted index>", separated by spaces
        for line in lines:
            line_split = line.split('\n')[0].split(' ')
            target = int(line_split[1])
            predict = int(line_split[2])
            confusion_matrix[target][predict] += 1

            # # save image classified error
            # if target != predict:
            #     new_name = line_split[0].split("/")[-1]
            #     new_path = os.path.join(save_path, model_name)
            #     new_path = os.path.join(new_path, cls_suffix)
            #     new_path = os.path.join(new_path, cls_dict[predict])
            #     new_path = os.path.join(new_path, new_name)
            #     shutil.copyfile(line_split[0], new_path)

        res_dict[model_name]['confusion_matrix'] = confusion_matrix
        print_confusion_matrix(confusion_matrix, model_name, cls_dict)
        cal_accuracy_rate(confusion_matrix, model_name, cls_dict)
        cal_other_index(confusion_matrix, res_dict[model_name], model_name, cls_dict)

    # make sure the output directory exists before saving the workbook
    os.makedirs(os.path.dirname(xls_path), exist_ok=True)
    write_excel(res_dict, xls_path, cls_dict)
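In classification tasks, the most commonly used evaluation tools are the confusion matrix, accuracy, precision, and recall. Which metric to improve, and how, depends on the task at hand.
Reposted from: http://mqani.baihongyu.com/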