
Tidyverse syntax for calculating precision and recall


I am trying to calculate AUC, precision, recall, and accuracy for every group in my data frame. I have a single data frame that holds the concatenated predictions of three different models.

What is the tidyverse syntax for this? I want to calculate these metrics with Max Kuhn's yardstick package.

Here is a sample data frame and what I have so far:

> library(tidyverse)
> library(yardstick)
> 
> sample_df <- tibble(
+     group_type = rep(c('a', 'b', 'c'), each = 5),  # repeats each element 5 times
+     true_label = as.factor(rbinom(15, 1, 0.3)),    # generates 1 with 30% prob
+     pred_prob = runif(15, 0, 1)                    # generates 15 decimals between 0 and 1 from uniform dist
+ ) %>%
+     mutate(pred_label = as.factor(if_else(pred_prob > 0.5, 1, 0)))
> 
> sample_df
# A tibble: 15 x 4
   group_type true_label pred_prob pred_label
   <chr>      <fct>          <dbl> <fct>     
 1 a          1             0.327  0         
 2 a          1             0.286  0         
 3 a          0             0.0662 0         
 4 a          0             0.993  1         
 5 a          0             0.835  1         
 6 b          0             0.975  1         
 7 b          0             0.436  0         
 8 b          0             0.585  1         
 9 b          0             0.478  0         
10 b          1             0.541  1         
11 c          1             0.247  0         
12 c          0             0.608  1         
13 c          0             0.215  0         
14 c          0             0.937  1         
15 c          0             0.819  1         
> 

Metrics:

> # metrics for the full data
> precision(sample_df, truth = true_label, estimate = pred_label)
[1] 0.5714286
> recall(sample_df, truth = true_label, estimate = pred_label)
[1] 0.3636364
> accuracy(sample_df, truth = true_label, estimate = pred_label)
[1] 0.3333333
> roc_auc(sample_df, truth = true_label, pred_prob)
[1] 0.7727273
> 
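The full-data numbers above make more sense once you check which class is being treated as "positive". The printed table reproduces them exactly if the first factor level, "0", is the event class: 4 of the 7 rows predicted "0" are truly "0" (precision), and 4 of the 11 truly-"0" rows are predicted "0" (recall). Because `pred_label` is 1 when `pred_prob > 0.5`, label 1 is probably the class of interest here, so this is worth confirming before you compare groups.

Below is a minimal base-R sketch of that arithmetic. It is not yardstick's code. The labels are copied from the printout.

```r
# The 15 labels printed above, in row order (copied from the question's output)
true_label <- factor(c(1, 1, 0, 0, 0,  0, 0, 0, 0, 1,  1, 0, 0, 0, 0))
pred_label <- factor(c(0, 0, 0, 1, 1,  1, 0, 1, 0, 1,  0, 1, 0, 1, 1))

# Confusion matrix: rows are predictions, columns are the true labels
cm <- table(predicted = pred_label, truth = true_label)

# Treat the first factor level, "0", as the event (positive) class
event <- levels(true_label)[1]

precision_0 <- cm[event, event] / sum(cm[event, ])   # 4 / 7
recall_0    <- cm[event, event] / sum(cm[, event])   # 4 / 11
accuracy    <- sum(diag(cm)) / sum(cm)               # 5 / 15

print(cm)
print(c(precision = precision_0, recall = recall_0, accuracy = accuracy))
```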

Now, how do I get these metrics for each group in my dataset?

sample_df %>%
    group_by(group_type) %>%
    summarize(???)

Solution

  • An example using unnest:

       # Assumes each yardstick metric returns a single number,
       # as in the output shown in the question
       sample_df %>% 
         group_by(group_type) %>% 
         do(auc = roc_auc(., true_label, pred_prob),
            acc = accuracy(., true_label, pred_label),
            recall = recall(., true_label, pred_label),
            precision = precision(., true_label, pred_label)) %>%
         unnest()
    

    However, I would actually suggest not using yardstick, because it doesn't play nicely with dplyr's summarize(). Under the hood it just uses the ROCR package. I would simply write your own functions that take two vectors (the truth and the estimate).

    yardstick is flawed because it requires a data frame as its first input; it is trying to be too clever. Under the dplyr framework that isn't necessary, because summarize() and mutate() already see the variables inside a data frame without an explicit data argument.
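    Following that suggestion, here is a minimal base-R sketch of such helpers. The names `my_precision`, `my_recall`, `my_accuracy`, and `my_roc_auc` are hypothetical; they are not part of yardstick or dplyr. The AUC uses the rank-based pair-counting formula rather than ROCR.

    Because each helper takes plain vectors, it can be called directly inside `summarize()`. To keep the sketch runnable without extra packages, the per-group step is done with `split()`, and the equivalent dplyr pipeline appears as a comment.

    The data are copied from the printed table in the question. `positive` defaults to "1" because `pred_label` is 1 when `pred_prob > 0.5`, which makes `pred_prob` the score for label 1.

```r
# Sample data copied from the question's printout (pred_prob as displayed)
sample_df <- data.frame(
  group_type = rep(c("a", "b", "c"), each = 5),
  true_label = factor(c(1, 1, 0, 0, 0,  0, 0, 0, 0, 1,  1, 0, 0, 0, 0)),
  pred_prob  = c(0.327, 0.286, 0.0662, 0.993, 0.835,
                 0.975, 0.436, 0.585, 0.478, 0.541,
                 0.247, 0.608, 0.215, 0.937, 0.819)
)
sample_df$pred_label <- factor(ifelse(sample_df$pred_prob > 0.5, 1, 0))

# Hypothetical helpers: each takes two vectors, so no data argument is needed.
# `positive` names the class treated as the event.
my_precision <- function(truth, estimate, positive = "1") {
  tp <- sum(estimate == positive & truth == positive)
  tp / sum(estimate == positive)          # NaN if nothing is predicted positive
}

my_recall <- function(truth, estimate, positive = "1") {
  tp <- sum(estimate == positive & truth == positive)
  tp / sum(truth == positive)             # NaN if there are no true positives
}

my_accuracy <- function(truth, estimate) {
  mean(as.character(truth) == as.character(estimate))
}

# Rank-based ROC AUC: share of (positive, negative) pairs in which the
# positive case has the higher score; ties count as 1/2.
my_roc_auc <- function(truth, score, positive = "1") {
  pos <- score[truth == positive]
  neg <- score[truth != positive]
  diffs <- outer(pos, neg, "-")
  (sum(diffs > 0) + 0.5 * sum(diffs == 0)) / length(diffs)
}

# Split-apply-combine in base R: one row of metrics per group
per_group <- do.call(rbind, lapply(split(sample_df, sample_df$group_type), function(d) {
  data.frame(
    group_type = d$group_type[1],
    auc        = my_roc_auc(d$true_label, d$pred_prob),
    acc        = my_accuracy(d$true_label, d$pred_label),
    recall     = my_recall(d$true_label, d$pred_label),
    precision  = my_precision(d$true_label, d$pred_label)
  )
}))
print(per_group)

# With dplyr loaded, the same helpers work directly in summarize():
# sample_df %>%
#   group_by(group_type) %>%
#   summarize(auc       = my_roc_auc(true_label, pred_prob),
#             acc       = my_accuracy(true_label, pred_label),
#             recall    = my_recall(true_label, pred_label),
#             precision = my_precision(true_label, pred_label))
```

    If you pass `positive = "0"` on the full data, you get the question's yardstick numbers (0.571, 0.364, 0.333 and 0.773). That is consistent with yardstick having treated the first factor level as the event.

    Watch for degenerate groups: precision is `NaN` when a group has no predicted positives, and AUC is `NaN` when a group contains only one class.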