python pytorch auc torcheval

Correctly invoke TorchEval AUC calculation for two tensors


I am new to torcheval and trying to measure the AUC of my binary classifier (doc).

I notice that while the classifier's accuracy is decent, the AUC metric evaluates to below 0.5, which seems wrong (accuracy is better than 50/50 and my classes are balanced). The AUC also differs from sklearn.metrics.roc_auc_score. A simple example:

import torch
from torcheval.metrics.aggregation.auc import AUC
from torcheval.metrics import BinaryAccuracy

from sklearn.metrics import roc_auc_score, accuracy_score

p_pred = torch.tensor([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])  # model est likelihood of target class
y_true = torch.tensor([0.0, 0.0, 1.0, 0.0, 1.0, 1.0])  # ground truth, 1 = target class

# TorchEval Metrics
auc_metric = AUC()
accuracy_metric = BinaryAccuracy(threshold=0.5)
auc_metric.reset()
accuracy_metric.reset()
auc_metric.update(p_pred,y_true)
accuracy_metric.update(input=p_pred,target=y_true)

print(f"TorchEval Accuracy = {accuracy_metric.compute().item():.3}")
print(f"Sklearn Accuracy   = {accuracy_score(y_true=y_true,y_pred=p_pred.round()):.3}")
print(f"TorchEval AUC      = {auc_metric.compute().item():.3}")
print(f"Sklearn AUC        = {roc_auc_score(y_true=y_true,y_score=p_pred):.3}")

This returns an unexpected value for the TorchEval AUC:

TorchEval Accuracy = 0.667
Sklearn Accuracy   = 0.667
TorchEval AUC      = 0.3
Sklearn AUC        = 0.889

How can I correctly invoke TorchEval AUC to get the expected value of ~0.9?


Solution

  • I should have been using metrics.BinaryAUROC (doc), not metrics.AUC. As far as I can tell, AUC is a generic area-under-a-curve metric for when you already have the coordinates of the curve, e.g. the (FPR, TPR) points of the ROC.

    The different argument names should have been a giveaway that AUC wasn't what I wanted: AUC.update() takes (x, y), whereas BinaryAccuracy.update() takes (input, target).

    from torcheval.metrics import BinaryAUROC
    auc_metric2 = BinaryAUROC()
    auc_metric2.reset()
    auc_metric2.update(input=p_pred,target=y_true)
    print(f"TorchEval AUC      = {auc_metric2.compute().item():.3}")
    

    This returns the expected result, in line with sklearn.
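
To see where the 0.3 comes from, here is a minimal plain-Python sketch. It assumes that AUC integrates y over x with the trapezoidal rule, which reproduces the value from the question. The helper names trapezoid_area and pairwise_auroc are made up for illustration and are not part of torcheval or sklearn. The second helper computes the ROC AUC as the fraction of (positive, negative) pairs that the scores rank correctly, which is one standard way to define it.

```python
# Plain-Python sketch; the helper names are illustrative, not library API.

def trapezoid_area(x, y):
    """Area under the piecewise-linear curve through the points (x[i], y[i])."""
    return sum(
        (x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2.0
        for i in range(len(x) - 1)
    )


def pairwise_auroc(scores, labels):
    """Share of (positive, negative) pairs ranked correctly; a tie counts 0.5."""
    pos = [s for s, t in zip(scores, labels) if t == 1]
    neg = [s for s, t in zip(scores, labels) if t == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


p_pred = [0.2, 0.3, 0.4, 0.6, 0.7, 0.8]  # same scores as in the question
y_true = [0, 0, 1, 0, 1, 1]              # same labels as in the question

# Treating (scores, labels) as curve coordinates, as AUC().update(p_pred, y_true) does
area_raw = trapezoid_area(p_pred, y_true)

# Ranking-based ROC AUC, the quantity the question was actually after
auroc = pairwise_auroc(p_pred, y_true)

print(f"Trapezoid over (p_pred, y_true) = {area_raw:.3f}")  # 0.300
print(f"Pairwise ROC AUC                = {auroc:.3f}")     # 0.889
```

So the 0.3 is just the area under a curve drawn through the (score, label) points, which has nothing to do with ranking quality, while 8 of the 9 positive/negative pairs are ordered correctly, giving 0.889.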