In the tabnet package, I want the loss to be balanced accuracy for multi-class classification, similar to yardstick::bal_accuracy_vec(). How can I do that?
I know how to compute a balanced accuracy, but I don't know how to write a function that would fit into the tabnet framework. So basically, any help on how to customize the loss in tabnet is welcome.
Example:
library(tabnet)
library(recipes)

data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

rec <-
  recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids, ]) %>%
  step_normalize(all_numeric(), -all_outcomes())

attrition_fit <-
  tabnet_fit(
    rec,
    data = attrition[ids, ],
    epochs = 2,
    valid_split = 0.2,
    loss = yardstick::bal_accuracy_vec
  )
This gives the error:

Error in x != y : comparison (2) is possible only for atomic and list types
Thanks for your help.
After asking the same thing on the tabnet GitHub repo, I got the following answer:
Custom loss is supported as a function in {tabnet}.
The good news is that you have an example here: https://github.com/mlverse/tabnet/blob/3d3ce9925cc12a13adf9e7b2c2c9ebd57c7f3d5e/tests/testthat/test-hardhat_parameters.R#L211-L229
The bad news is that you must rely on a {torch} loss, and you cannot use a {yardstick} metric as a loss. The reason is that the function must be differentiable to drive the gradient in the right direction...
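For concreteness, here is my own minimal sketch of that (not part of the answer). Since balanced accuracy itself is not differentiable, cross-entropy weighted by inverse class frequencies is a common differentiable stand-in, and {torch}'s nn_cross_entropy_loss() accepts such weights. The single-outcome recipe and the weighting scheme are my assumptions, not something the answer prescribes:

library(tabnet)
library(recipes)
library(torch)

# single-outcome recipe, since the class weights below only cover Attrition
rec_single <-
  recipe(Attrition ~ ., data = attrition[ids, ]) %>%
  step_normalize(all_numeric(), -all_outcomes())

# inverse-frequency class weights (illustrative, computed on the sampled rows)
counts <- table(attrition$Attrition[ids])
class_weights <- torch_tensor(as.numeric(sum(counts) / counts))

attrition_fit <-
  tabnet_fit(
    rec_single,
    data = attrition[ids, ],
    epochs = 2,
    valid_split = 0.2,
    loss = nn_cross_entropy_loss(weight = class_weights)
  )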
There are plenty of loss functions in {torch}, all ending with _loss. If the one you need is missing, you can also build your own as a torch module, like entmax and sparsemax in https://github.com/mlverse/tabnet/blob/main/R/sparsemax.R (see the sketch after this answer).
Hope it helps,
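Following up on the sparsemax.R pattern linked above: my understanding is that a custom loss can be written as a torch nn_module whose forward() uses only differentiable torch operations, so autograd can backpropagate through it. A hypothetical sketch (the module name and the weighting idea are mine, not from the tabnet sources):

library(torch)

nn_weighted_ce_loss <- nn_module(
  "nn_weighted_ce_loss",
  initialize = function(weight = NULL) {
    # optional per-class weight tensor, e.g. inverse class frequencies
    self$weight <- weight
  },
  forward = function(input, target) {
    # nnf_cross_entropy() is differentiable, so gradients flow through it
    nnf_cross_entropy(input, target, weight = self$weight)
  }
)

# then pass an instance as the loss, e.g.:
# tabnet_fit(rec_single, data = ..., loss = nn_weighted_ce_loss(class_weights))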