
Extract weights from fitted regr.nnet object in mlr3


This question is related to the solution provided by @Sebastian for a previous question, which showed how to train a regr.nnet learner repeatedly using a custom (i.e., fixed) resampling strategy and cloned learners.

library(mlr3learners)
library(dplyr)
library(ggplot2)

set.seed(4123)
x <- 1:20
obs <- data.frame(
  x = rep(x, 3),
  f = factor(rep(c("a", "b", "c"), each = 20)),
  y = c(3 * dnorm(x, 10, 3), 5 * dlnorm(x, 2, 0.5), dexp(20 - x, .5)) + 
        rnorm(60, sd = 0.02)
)

x_test <- seq(0, 20, length.out = 100)
# one row per (x, f) combination, built like `obs`
# (expand.grid() over y would create 100 * 3 * 300 = 90000 rows)
test <- data.frame(
  x = rep(x_test, 3),
  f = factor(rep(c("a", "b", "c"), each = 100)),
  y = c(3 * dnorm(x_test, 10, 3), 5 * dlnorm(x_test, 2, 0.5),
        dexp(20 - x_test, .5)) + rnorm(300, sd = 0.02)
)

dat <- rbind(obs, test)
task <- as_task_regr(dat, target = "y")
resampling <- rsmp("custom")
# train on the 60 observations, test on all remaining rows
resampling$instantiate(task,
  train_sets = list(1:60),
  test_sets = list(61:nrow(dat))
)
learner <- lrn("regr.nnet", size = 5, trace = FALSE)

learners <- replicate(10, learner$clone())
design <- benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = resampling
)
bmr <- benchmark(design)

The next step is to evaluate the benchmark further and to use the model for further evaluation, both within and outside of mlr3. Below, I tried to evaluate model performance and to plot the predictions for the test data:

## evaluate quality criteria
bmr$aggregate()[learner_id == "regr.nnet"] # ok
bmr$aggregate(msr("time_train")) # works
# bmr$aggregate(msr("regr.rmse"), msr("regr.rsq"), msr("regr.bias")) # not possible

## select the best fit
i_best  <- which.min(bmr$aggregate()$regr.mse)
best    <- bmr$resample_result(i_best)

## do prediction
pr      <- as.data.table(best$predictions()[[1]])$response

## visualization
pred_test <- test |> mutate(y = pr)
ggplot(obs, aes(x, y)) + geom_point() +
  geom_line(data = pred_test, mapping = aes(x, y)) +
  facet_wrap(~f)

The R6 style of course has its advantages (I was myself involved in the development of the pre-R6 proto package), but it is sometimes not easy to find the best way to access internal data. The mlr3 book is very helpful, but some questions remain:

  1. Is it easily possible to extract additional measures, e.g. msr("regr.rmse") from the benchmark object?
  2. I am not satisfied with my code line pr <- as.data.table() ....., but found no better way yet.
  3. Finally, I want to access the internal data structure of the fitted nnet to extract the raw weights for "offline" use of the neural network outside of R.

Solution

    1. (added by myself) If you create teaching material using mlr3 and would like to share it, you can create an issue or PR in the mlr-org/mlr3website repository. On the mlr-org website we have a resources tab where we can link stuff like that :) https://mlr-org.com/resources.html

    2. Is it easily possible to extract additional measures, e.g. msr("regr.rmse") from the benchmark object?

    The $aggregate() method takes a single list of measures (e.g. one constructed with msrs()), not several separate measure arguments (see code below).

    3. I am not satisfied with my code line pr <- as.data.table() ....., but found no better way yet.

    You can use best$predictions()[[1]]$response directly, without the conversion.

    4. Finally, I want to access the internal data structure of the fitted nnet to extract the raw weights for "offline" use of the neural network outside of R.

    We do not meddle with the internal structures of the fitted objects. They can be accessed through the $model slot of a trained learner; in a benchmark, the trained models are only kept if you call benchmark() with store_models = TRUE (see code below).
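    For a regression benchmark like the one in the question, passing several measures to $aggregate() and reading the response without as.data.table() can be sketched as below. This is a minimal, self-contained sketch: to keep it fast it swaps the question's task and cloned regr.nnet learners for the built-in mtcars task and the regr.featureless and regr.rpart learners, and uses a holdout split. With the question's bmr, the same $aggregate(), $resample_result() and $predictions() calls apply unchanged.

```r
library(mlr3)

set.seed(1)
task <- tsk("mtcars")
design <- benchmark_grid(
  tasks = task,
  learners = lrns(c("regr.featureless", "regr.rpart")),
  resamplings = rsmp("holdout")
)
bmr <- benchmark(design)

# Several measures at once: $aggregate() expects one *list* of measures,
# which msrs() builds from a character vector of measure keys.
agg <- bmr$aggregate(msrs(c("regr.mse", "regr.rmse", "regr.rsq", "regr.bias")))
print(agg)

# Pick the learner with the lowest MSE and take the response of its first
# (here: only) resampling iteration directly from the Prediction object.
best <- bmr$resample_result(which.min(agg$regr.mse))
pr <- best$predictions()[[1]]$response
print(head(pr))
```

    Because the holdout strategy has a single iteration, the aggregated regr.rmse here is simply the square root of the aggregated regr.mse.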

    library(mlr3)
    library(mlr3learners)
    
    learner = lrns(c("classif.rpart", "classif.nnet"))
    task = tsk("iris")
    resampling = rsmp("holdout")
    
    design = benchmark_grid(
      tasks = task,
      learners = learner,
      resamplings = resampling
    )
    
    bmr = benchmark(design, store_models = TRUE)
    #> INFO  [07:13:52.802] [mlr3] Running benchmark with 2 resampling iterations
    #> INFO  [07:13:52.889] [mlr3] Applying learner 'classif.rpart' on task 'iris' (iter 1/1)
    #> INFO  [07:13:52.921] [mlr3] Applying learner 'classif.nnet' on task 'iris' (iter 1/1)
    #> # weights:  27
    #> initial  value 118.756408 
    #> iter  10 value 58.639749
    #> iter  20 value 45.676852
    #> iter  30 value 21.336083
    #> iter  40 value 8.646964
    #> iter  50 value 6.041813
    #> iter  60 value 5.906140
    #> iter  70 value 5.902865
    #> iter  80 value 5.898339
    #> final  value 5.898161 
    #> converged
    #> INFO  [07:13:52.946] [mlr3] Finished benchmark
    
    bmr$aggregate(msrs(c("classif.acc", "time_train")))
    #>    nr      resample_result task_id    learner_id resampling_id iters
    #> 1:  1 <ResampleResult[21]>    iris classif.rpart       holdout     1
    #> 2:  2 <ResampleResult[21]>    iris  classif.nnet       holdout     1
    #>    classif.acc time_train
    #> 1:        0.98      0.007
    #> 2:        1.00      0.005
    
    # get the first resample result
    rr1 = bmr$resample_result(1)
    
    # Get the model from the first resampling iteration of this ResampleResult
    rr1$learners[[1]]$model
    #> n= 100 
    #> 
    #> node), split, n, loss, yval, (yprob)
    #>       * denotes terminal node
    #> 
    #> 1) root 100 66 versicolor (0.32000000 0.34000000 0.34000000)  
    #>   2) Petal.Length< 2.45 32  0 setosa (1.00000000 0.00000000 0.00000000) *
    #>   3) Petal.Length>=2.45 68 34 versicolor (0.00000000 0.50000000 0.50000000)  
    #>     6) Petal.Width< 1.75 37  4 versicolor (0.00000000 0.89189189 0.10810811) *
    #>     7) Petal.Width>=1.75 31  1 virginica (0.00000000 0.03225806 0.96774194) *
    

    Created on 2023-03-08 with reprex v2.0.2
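For the raw weights of the neural network: the object in $model of a trained regr.nnet learner is the fitted nnet object, so the weights are in its $wts vector (named by connection via coef()), and the layer sizes are in $n. The sketch below trains one learner on the built-in mtcars task, exports the named weights to a temporary CSV file (the file name is hypothetical) and then reproduces the predictions with a hand-written forward pass, which is what an offline implementation would have to do. It rests on assumptions that the final comparison with mlr3's predictions checks: the weight layout of a net without skip-layer connections is, for each hidden unit, its bias followed by one weight per input, then for each output unit its bias followed by one weight per hidden unit; hidden units are logistic; the output is linear because, as far as I know, mlr3learners fits regr.nnet with linout = TRUE; and inputs are used unscaled, as the model matrix of the training formula without the intercept (for a factor such as f in the question, this means dummy columns).

```r
library(mlr3)
library(mlr3learners)

set.seed(42)
task <- tsk("mtcars")
learner <- lrn("regr.nnet", size = 2, trace = FALSE, maxit = 200)
learner$train(task)
# With the question's benchmark, run benchmark(design, store_models = TRUE)
# and use best$learners[[1]]$model instead of learner$model.

net <- learner$model          # the fitted nnet object
ni  <- net$n[1]               # number of inputs
nh  <- net$n[2]               # number of hidden units
no  <- net$n[3]               # number of outputs
w   <- net$wts                # all weights as one numeric vector
print(head(coef(net)))        # same weights, named by connection (e.g. "b->h1")

# Export the named weights for use outside of R (hypothetical file name).
weights_file <- tempfile("nnet_weights_", fileext = ".csv")
write.csv(data.frame(connection = names(coef(net)), weight = unname(coef(net))),
          weights_file, row.names = FALSE)

# Assumed layout (no skip layer): column j of W_hidden holds the bias and
# input weights of hidden unit j; W_out holds the output bias and hidden weights.
W_hidden <- matrix(w[seq_len(nh * (ni + 1))], nrow = ni + 1)
W_out    <- matrix(w[nh * (ni + 1) + seq_len(no * (nh + 1))], nrow = nh + 1)

# Inputs as nnet.formula builds them: model matrix of the training terms
# without the intercept column (factor features become dummy columns).
X <- model.matrix(delete.response(net$terms), data = task$data())
X <- X[, colnames(X) != "(Intercept)", drop = FALSE]

# Forward pass: logistic hidden units, linear output (assumes linout = TRUE).
sigmoid  <- function(z) 1 / (1 + exp(-z))
H        <- sigmoid(cbind(1, X) %*% W_hidden)
y_manual <- drop(cbind(1, H) %*% W_out)

# Compare with mlr3's predictions on the same rows; the small tolerance allows
# for floating-point differences from nnet's C implementation.
y_mlr3 <- learner$predict(task)$response
print(all.equal(unname(y_manual), unname(y_mlr3), tolerance = 1e-4))
```

Reproducing the predictions inside R before porting the weights is a cheap way to confirm that the layout and activation assumptions hold for a given fit; if skip = TRUE or a different output type is used, the layout and the output activation change accordingly.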