
R package tabnet, how to change the loss to balanced accuracy?


In the tabnet package, I want to use balanced accuracy as the loss for multi-class classification, similar to yardstick::bal_accuracy_vec(). How can I do that?

I know how to compute balanced accuracy, but I don't know how to write a function that fits into the tabnet framework. Any help on how to customize the loss in tabnet is welcome.

Example:

library(tabnet)
library(recipes)

data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

rec <-
  recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids, ]) %>%
  step_normalize(all_numeric(), -all_outcomes())

attrition_fit <-
  tabnet_fit(rec,
             data = attrition[ids, ],
             epochs = 2,
             valid_split = 0.2,
             loss = yardstick::bal_accuracy_vec
             )

This gives the error: Error in x != y : comparison (2) is possible only for atomic and list types.

Thanks for your help.


Solution

  • After asking the same question on the tabnet GitHub repo, I got this answer:

    Custom loss is supported as a function in {tabnet}

    The good news is that there is an example here: https://github.com/mlverse/tabnet/blob/3d3ce9925cc12a13adf9e7b2c2c9ebd57c7f3d5e/tests/testthat/test-hardhat_parameters.R#L211-L229

    The bad news is that you must rely on a {torch} loss; you cannot use a {yardstick} metric as a loss. The reason is that the loss function must be differentiable, so that its gradient can drive the optimization in the right direction.

    There are plenty of loss functions in {torch}, all ending with _loss. If the one you need is missing, you can also build your own as a torch module, like entmax and sparsemax in https://github.com/mlverse/tabnet/blob/main/R/sparsemax.R

    Hope it helps,
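Since balanced accuracy itself is not differentiable, one common workaround is a class-weighted cross-entropy, which gives every class the same total influence on the loss. The sketch below is not from the tabnet maintainers: `balanced_class_weights()` is a hypothetical helper, the toy labels are made up for illustration, and inverse-frequency weighting is only one possible surrogate. It relies on `torch::nn_cross_entropy_loss()` and its `weight` argument.

```r
library(torch)

# Hypothetical helper (not part of tabnet or torch): inverse-frequency
# class weights, n / (k * n_class), so rare classes count as much as
# frequent ones. Assumes every factor level occurs at least once.
balanced_class_weights <- function(y) {
  counts <- table(y)
  as.numeric(length(y) / (length(counts) * counts))
}

# Toy imbalanced outcome: 3 "No" vs 1 "Yes"
y <- factor(c("No", "No", "No", "Yes"))
w <- balanced_class_weights(y)

# A differentiable torch loss that can be called as loss(output, target)
weighted_ce <- nn_cross_entropy_loss(
  weight = torch_tensor(w, dtype = torch_float())
)

# Sanity check on random logits; torch for R expects 1-based class indices
logits <- torch_randn(4, 2, requires_grad = TRUE)
target <- torch_tensor(as.integer(y), dtype = torch_long())
loss <- weighted_ce(logits, target)
loss$backward()  # gradients flow, unlike with a yardstick metric
```

Assuming tabnet calls the loss as a function of the network output and the target tensor, as in the linked test, this object could then be passed to a single-outcome fit, for example `tabnet_fit(Attrition ~ ., data = attrition[ids, ], epochs = 2, valid_split = 0.2, loss = weighted_ce, device = "cpu")`, with the weights computed from `attrition$Attrition[ids]`. The `device = "cpu"` setting is there because the weight tensor in this sketch lives on the CPU. Balanced accuracy can still be computed afterwards on predictions with yardstick::bal_accuracy_vec() as an evaluation metric.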