I have a multiclass classification problem and want to build a precision-recall curve using pr_curve from the yardstick library in R. This function requires a tibble with one probability column per class, like the built-in data(hpc_cv) dataset.
How do I get there from my classification results, stored as columns in a tibble?
library(dplyr)
library(yardstick)
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")),
               expected = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)
I have not found a function in yardstick (or elsewhere) to calculate those probabilities.
I am not sure how class probs are calculated; I am thinking along these lines:
data %>% filter(predicted == "A") %>% summarise(n = n() / 6)
Is this correct? If so, I wonder whether there is a nice way to do it without for-loops over each class in each fold, so that I end up with a tibble like hpc_cv above.
"I am not sure how class probs are calculated"
Class probabilities are generated by a specific model for each individual data point; they cannot be recovered after the fact by counting hard class predictions, as in your summarise() above.
PR curves (and precision and recall) are metrics for data sets where the outcome has two classes. You can do multiclass averaging to get an overall PR curve AUC though.
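For the two-class case, pr_curve() only needs the truth column and the probability column for the event class. A minimal sketch using yardstick's built-in two_class_example data (not part of the reprex below):
library(yardstick)
data(two_class_example)
# truth is a two-level factor; Class1 holds the predicted probability
# of the first (event) level
pr_curve(two_class_example, truth, Class1)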
There is an example below but I would advise reading the tidymodels book for a bit before proceeding.
library(nnet) # <- engine used by multinom_reg() for multinom_fit below
library(tidymodels)
tidymodels_prefer()
data(hpc_data, package = "modeldata")
set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test <- testing(hpc_split)
set.seed(2)
multinom_fit <-
  multinom_reg() %>%
  fit(class ~ iterations + compounds, data = hpc_train)
test_predictions <- augment(multinom_fit, new_data = hpc_test)
# examples of the hard class predictions and the
# predicted probabilities:
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#>   .pred_class .pred_VF .pred_F .pred_M .pred_L
#>   <fct>          <dbl>   <dbl>   <dbl>   <dbl>
#> 1 VF             0.641   0.279  0.0670  0.0128
#> 2 VF             0.640   0.280  0.0671  0.0128
#> 3 VF             0.628   0.287  0.0711  0.0138
#> 4 VF             0.628   0.287  0.0711  0.0138
#> 5 VF             0.626   0.288  0.0716  0.0139
#> 6 VF             0.626   0.288  0.0719  0.0140
# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#>           Truth
#> Prediction  VF   F   M   L
#>         VF 516 278  74  16
#>         F   18  46  36   4
#>         M    2   7  19  21
#>         L    0  11   7  28
# create a metric set:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recall, and the PR AUC are calculated using macro weighting of 4
# different one-vs-all results.
# See https://yardstick.tidymodels.org/articles/multiclass.html
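# An added sketch (not from the original answer), assuming the objects above:
# "macro" just means the unweighted mean of the four one-vs-all values.
# For precision:
mean(sapply(levels(hpc_train$class), function(lvl) {
  truth_bin <- factor(test_predictions$class == lvl, levels = c(TRUE, FALSE))
  est_bin <- factor(test_predictions$.pred_class == lvl, levels = c(TRUE, FALSE))
  precision_vec(truth_bin, est_bin)
}))
# should match the macro precision reported below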
# evaluate them
test_predictions %>%
# See ?metric_set for more information. We pass the truth (class), all of the
# predicted probability columns (.pred_VF:.pred_L), and the named hard class
# predictions.
cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 accuracy  multiclass     0.562
#> 2 precision macro          0.506
#> 3 recall    macro          0.411
#> 4 pr_auc    macro          0.481
Created on 2022-12-09 with reprex v2.0.2
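One addition beyond the reprex: since the question asked about pr_curve() itself, you can pass the same truth and probability columns to it. For a multiclass outcome it computes one one-vs-all curve per class, and autoplot() draws them all:
# returns .level, .threshold, recall, and precision, one curve per class
test_predictions %>%
  pr_curve(class, .pred_VF:.pred_L) %>%
  autoplot()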