python scikit-learn confusion-matrix

Updating Confusion Matrix for Scikit-learn


I have been working on a Jupyter Notebook that reads a CSV file, manipulates the data, and produces various models along with visualizations that describe them.

One tool I'm using is a confusion matrix from scikit-learn. Originally I used the plot_confusion_matrix function, but after upgrading scikit-learn through pip I noticed that this function has been deprecated (in 1.0) and removed (in 1.2), and replaced by ConfusionMatrixDisplay.

I am finding it difficult to switch my code over to the new class without getting errors. Does anybody know how to rewrite the code below for the current scikit-learn API?

import itertools

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from sklearn.metrics import confusion_matrix

def plot_confusMatrix(cm, classes,
                      title='Confusion matrix',
                      cmap=plt.cm.Blues):
    # Draw the confusion matrix as a heat map with a colorbar
    plt.rcParams.update({'font.size': 19})
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontdict={'size': '16'})
    plt.colorbar()
    # Label both axes with the class names
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=12, color="blue")
    plt.yticks(tick_marks, classes, fontsize=12, color="blue")
    rc('font', weight='bold')
    fmt = '.1f'
    # Write each cell's value in the centre of its square
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="red")

    plt.ylabel('True label', fontdict={'size': '16'})
    plt.xlabel('Predicted label', fontdict={'size': '16'})
    plt.tight_layout()

plot_confusMatrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud', 'Fraud'],
                  title='Confusion matrix')
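For comparison, here is a minimal sketch of the same styling built on ConfusionMatrixDisplay, assuming scikit-learn 1.0 or later. The helper name plot_confus_matrix_display is hypothetical. Instead of drawing every cell by hand, it lets disp.plot() draw the heat map, colorbar and cell values, then restyles the Axes (disp.ax_) and the per-cell Text objects (disp.text_). The global rcParams font size from the original function is not reproduced here.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_confus_matrix_display(cm, classes, title="Confusion matrix", cmap="Blues"):
    """Hypothetical helper: draw `cm` with ConfusionMatrixDisplay, styled like the original."""
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    # values_format="d" prints integer counts (use ".1f" if cm holds floats);
    # plot() creates the figure, the axes, the colorbar and the cell texts
    disp.plot(cmap=cmap, values_format="d", colorbar=True, xticks_rotation=45)
    ax = disp.ax_
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Predicted label", fontsize=16)
    ax.set_ylabel("True label", fontsize=16)
    # Blue tick labels, as in the original function
    plt.setp(ax.get_xticklabels(), fontsize=12, color="blue")
    plt.setp(ax.get_yticklabels(), fontsize=12, color="blue")
    # disp.text_ holds one matplotlib Text per cell; recolour them red and bold
    for text in disp.text_.ravel():
        text.set_color("red")
        text.set_fontweight("bold")
    disp.figure_.tight_layout()
    return disp


# Usage with toy labels standing in for the notebook's y_test / y_pred
disp = plot_confus_matrix_display(
    confusion_matrix([0, 0, 1, 1, 1], [0, 1, 1, 1, 0]),
    classes=["Non Fraud", "Fraud"],
)
plt.show()
```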

and

plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
                      title='Confusion matrix')

Both of these used to produce the expected plot, but I cannot get them to work with the new ConfusionMatrixDisplay API.



Solution

  • Try this code:

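    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
    cm = confusion_matrix(y_test, predictions, labels=clf.classes_)  # predictions = clf.predict(X_test)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=clf.classes_)
    disp.plot()
    plt.show()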
    

    Alternatively, see the question "How can I plot a confusion matrix?"