Tags: python, machine-learning, scikit-learn, multilabel-classification, precision-recall

Calculating precision, recall and F1 score per class in a multilabel classification problem


I'm trying to calculate the precision, recall, and F1 score per class in my multilabel classification problem. However, I think I'm doing something wrong, because the values seem too high: the F1 score for the whole problem is 0.66, yet I'm getting over 0.8 F1 for the individual classes.

This is how I am doing it right now:

confusion_matrix = multilabel_confusion_matrix(gold_labels, predictions)

assert(len(confusion_matrix) == 6)

for label in range(len(labels_reduced)):

    tp = confusion_matrix[label][0][0]
    fp = confusion_matrix[label][0][1]
    fn = confusion_matrix[label][1][0]
    tn = confusion_matrix[label][1][1]

    precision = tp+fp
    precision = tp/precision

    recall = tp+fn
    recall = tp/recall

    f1_score_up = precision * recall
    f1_score_down = precision + recall
    f1_score = f1_score_up/f1_score_down
    f1_score = 2 * f1_score

    print(f"Metrics for {labels_reduced[label]}.")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1-Score: {f1_score}")

Are these results okay? Do they make sense? Am I doing something wrong? How would you calculate these metrics? I'm using Hugging Face Transformers to load the models and get the predictions, and sklearn to calculate the metrics.


Solution

  • You could use the classification_report function from sklearn:

    from sklearn.metrics import classification_report
    
    labels = [[0, 1, 1], [1, 0, 0], [1, 0, 1]]
    predictions = [[0, 0, 1], [1, 0, 0], [1, 1, 1]]
    
    report = classification_report(labels, predictions)
    print(report)
    

    Which outputs:

                  precision    recall  f1-score   support
    
               0       1.00      1.00      1.00         2
               1       0.00      0.00      0.00         1
               2       1.00      1.00      1.00         2
    
       micro avg       0.80      0.80      0.80         5
       macro avg       0.67      0.67      0.67         5
    weighted avg       0.80      0.80      0.80         5
     samples avg       0.89      0.83      0.82         5
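
  • A likely reason the per-class scores in the question look too high: each per-class matrix returned by multilabel_confusion_matrix is laid out as [[tn, fp], [fn, tp]], so tp is at index [1][1], not [0][0] — the loop in the question swaps tp and tn. A minimal sketch of the per-class computation with the correct layout, using the same toy data as above:

    ```python
    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix

    labels = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 1]])
    predictions = np.array([[0, 0, 1], [1, 0, 0], [1, 1, 1]])

    # One 2x2 matrix per class, each laid out as [[tn, fp], [fn, tp]]
    mcm = multilabel_confusion_matrix(labels, predictions)

    per_class = []
    for matrix in mcm:
        tn, fp, fn, tp = matrix.ravel()  # unpack in the documented order
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
        per_class.append((precision, recall, f1))

    for idx, (p, r, f) in enumerate(per_class):
        print(f"class {idx}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
    ```

    This reproduces the per-class rows of the classification_report output (1.00/1.00/1.00 for classes 0 and 2, 0.00/0.00/0.00 for class 1). The guards against zero denominators mirror what sklearn does when it reports 0.0 for a class with no predicted or no true positives.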