rbart

Is there a way to retrieve the data from a BART package model in R?


I was wondering if there was a way to retrieve the data from a model built from the BART package in R?

It seems to be possible with other BART implementations, such as the dbarts package, but I can't find a way to get the original data back from a model fitted with the BART package. For example, suppose I create some data and fit both a BART model and a dbarts model, like so:

# Minimal example: fit the same simulated data with wbart() and bart()
# so the contents of the two fitted objects can be compared.
library(BART)
library(dbarts)

# create data: three columns of 100 uniform random draws each
# (no seed is set, so the values differ from run to run)
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART: columns 1-2 (x, y) are the predictors, column 3 (z) the response
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts: same predictors and response; keeptrees = TRUE is the
# option that makes the training data retrievable from the fitted object
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)

Using the keeptrees option in dbarts allows me to retrieve the data using:

# retrieve data from dbarts: `fit` is the stored sampler object, `data` is
# its data field, and the `x` slot holds the training predictor matrix
DBARTSmodel$fit$data@x

However, there doesn't seem to be any type of similar option when using BART. Is it even possible to retrieve the data from a BART model?


Solution

  • The Value section of ?wbart does not list the input data among the components that wbart returns, and none of wbart's arguments appear to change that.

    Furthermore, the output of str(BARTmodel) below confirms that the training data is not stored anywhere in the fitted wbart object.

    # Same setup as in the question: simulate data, then fit it with both
    # wbart() and bart() so the two fitted objects can be inspected.
    library(BART)
    library(dbarts)
    
    # create data: three columns of 100 uniform random draws each
    # (no seed is set, so the values differ from run to run)
    df <- data.frame(
      x = runif(100),
      y = runif(100),
      z = runif(100)
    )
    
    # create BART: columns 1-2 (x, y) are the predictors, column 3 (z) the response
    BARTmodel <- wbart(x.train = df[,1:2],
                       y.train = df[,3])
    
    # create dbarts: same predictors and response, keeping the trees
    DBARTSmodel <- bart(x.train = df[,1:2],
                        y.train = df[,3],
                        keeptrees = TRUE)
    
    # Inspect every component of the wbart fit; none of them holds x.train or y.train
    str(BARTmodel)
    #> List of 13
    #>  $ sigma          : num [1:1100] 0.258 0.262 0.295 0.278 0.273 ...
    #>  $ yhat.train.mean: num [1:100] 0.584 0.457 0.505 0.54 0.403 ...
    #>  $ yhat.train     : num [1:1000, 1:100] 0.673 0.62 0.433 0.711 0.634 ...
    #>  $ yhat.test.mean : num(0) 
    #>  $ yhat.test      : num[1:1000, 0 ] 
    #>  $ varcount       : int [1:1000, 1:2] 109 114 111 118 115 114 115 110 114 117 ...
    #>   ..- attr(*, "dimnames")=List of 2
    #>   .. ..$ : NULL
    #>   .. ..$ : chr [1:2] "x" "y"
    #>  $ varprob        : num [1:1000, 1:2] 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 ...
    #>   ..- attr(*, "dimnames")=List of 2
    #>   .. ..$ : NULL
    #>   .. ..$ : chr [1:2] "x" "y"
    #>  $ treedraws      :List of 2
    #>   ..$ cutpoints:List of 2
    #>   .. ..$ x: num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
    #>   .. ..$ y: num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
    #>   ..$ trees    : chr "1000 200 2\n1\n1 0 0 0.01185590432\n3\n1 1 30 -0.01530736435\n2 0 0 0.01064412946\n3 0 0 0.02413784284\n3\n1 0 "| __truncated__
    #>  $ proc.time      : 'proc_time' Named num [1:5] 1.406 0.008 1.415 0 0
    #>   ..- attr(*, "names")= chr [1:5] "user.self" "sys.self" "elapsed" "user.child" ...
    #>  $ mu             : num 0.501
    #>  $ varcount.mean  : Named num [1:2] 115 110
    #>   ..- attr(*, "names")= chr [1:2] "x" "y"
    #>  $ varprob.mean   : Named num [1:2] 0.5 0.5
    #>   ..- attr(*, "names")= chr [1:2] "x" "y"
    #>  $ rm.const       : int [1:2] 1 2
    #>  - attr(*, "class")= chr "wbart"
    

    In contrast, the output of str() for the dbarts model (fitted with bart()), while long, does contain the input data: see the y and x slots under fit$data.

    # Inspect the dbarts fit; the training data appears under fit -> data (slots y and x)
    str(DBARTSmodel)
    #> List of 11
    #>  $ call           : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
    #>  $ first.sigma    : num [1:100] 0.289 0.311 0.268 0.253 0.242 ...
    #>  $ sigma          : num [1:1000] 0.288 0.307 0.248 0.257 0.293 ...
    #>  $ sigest         : num 0.295
    #>  $ yhat.train     : num [1:1000, 1:100] 0.715 0.677 0.508 0.51 0.827 ...
    #>  $ yhat.train.mean: num [1:100] 0.583 0.456 0.504 0.544 0.404 ...
    #>  $ yhat.test      : NULL
    #>  $ yhat.test.mean : NULL
    #>  $ varcount       : int [1:1000, 1:2] 128 118 120 142 130 145 145 150 138 138 ...
    #>   ..- attr(*, "dimnames")=List of 2
    #>   .. ..$ : NULL
    #>   .. ..$ : chr [1:2] "x" "y"
    #>  $ y              : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
    #>  $ fit            :Reference class 'dbartsSampler' [package "dbarts"] with 5 fields
    #>   ..$ pointer:<externalptr> 
    #>   ..$ control:Formal class 'dbartsControl' [package "dbarts"] with 18 slots
    #>   .. .. ..@ binary          : logi FALSE
    #>   .. .. ..@ verbose         : logi TRUE
    #>   .. .. ..@ keepTrainingFits: logi TRUE
    #>   .. .. ..@ useQuantiles    : logi FALSE
    #>   .. .. ..@ keepTrees       : logi TRUE
    #>   .. .. ..@ n.samples       : int 1000
    #>   .. .. ..@ n.burn          : int 100
    #>   .. .. ..@ n.trees         : int 200
    #>   .. .. ..@ n.chains        : int 1
    #>   .. .. ..@ n.threads       : int 1
    #>   .. .. ..@ n.thin          : int 1
    #>   .. .. ..@ printEvery      : int 100
    #>   .. .. ..@ printCutoffs    : int 0
    #>   .. .. ..@ rngKind         : chr "default"
    #>   .. .. ..@ rngNormalKind   : chr "default"
    #>   .. .. ..@ rngSeed         : int NA
    #>   .. .. ..@ updateState     : logi TRUE
    #>   .. .. ..@ call            : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
    #>   ..$ model  :Formal class 'dbartsModel' [package "dbarts"] with 9 slots
    #>   .. .. ..@ p.birth_death  : num 0.5
    #>   .. .. ..@ p.swap         : num 0.1
    #>   .. .. ..@ p.change       : num 0.4
    #>   .. .. ..@ p.birth        : num 0.5
    #>   .. .. ..@ node.scale     : num 0.5
    #>   .. .. ..@ tree.prior     :Formal class 'dbartsCGMPrior' [package "dbarts"] with 3 slots
    #>   .. .. .. .. ..@ power             : num 2
    #>   .. .. .. .. ..@ base              : num 0.95
    #>   .. .. .. .. ..@ splitProbabilities: num(0) 
    #>   .. .. ..@ node.prior     :Formal class 'dbartsNormalPrior' [package "dbarts"] with 0 slots
    #>  list()
    #>   .. .. ..@ node.hyperprior:Formal class 'dbartsFixedHyperprior' [package "dbarts"] with 1 slot
    #>   .. .. .. .. ..@ k: num 2
    #>   .. .. ..@ resid.prior    :Formal class 'dbartsChiSqPrior' [package "dbarts"] with 2 slots
    #>   .. .. .. .. ..@ df      : num 3
    #>   .. .. .. .. ..@ quantile: num 0.9
    #>   ..$ data   :Formal class 'dbartsData' [package "dbarts"] with 10 slots
    #>   .. .. ..@ y                    : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
    #>   .. .. ..@ x                    : num [1:100, 1:2] 0.152 0.666 0.967 0.248 0.668 ...
    #>   .. .. .. ..- attr(*, "dimnames")=List of 2
    #>   .. .. .. .. ..$ : NULL
    #>   .. .. .. .. ..$ : chr [1:2] "x" "y"
    #>   .. .. .. ..- attr(*, "drop")=List of 2
    #>   .. .. .. .. ..$ x: logi FALSE
    #>   .. .. .. .. ..$ y: logi FALSE
    #>   .. .. .. ..- attr(*, "term.labels")= chr [1:2] "x" "y"
    #>   .. .. ..@ varTypes             : int [1:2] 0 0
    #>   .. .. ..@ x.test               : NULL
    #>   .. .. ..@ weights              : NULL
    #>   .. .. ..@ offset               : NULL
    #>   .. .. ..@ offset.test          : NULL
    #>   .. .. ..@ n.cuts               : int [1:2] 100 100
    #>   .. .. ..@ sigma                : num 0.295
    #>   .. .. ..@ testUsesRegularOffset: logi NA
    #>   ..$ state  :List of 1
    #>   .. ..$ :Formal class 'dbartsState' [package "dbarts"] with 6 slots
    #>   .. .. .. ..@ trees     : int [1:1055] 0 18 -1 0 49 -1 -1 0 60 -1 ...
    #>   .. .. .. ..@ treeFits  : num [1:100, 1:200] -0.02252 0.00931 0.00931 0.02688 0.00931 ...
    #>   .. .. .. ..@ savedTrees: int [1:2340360] 0 797997482 1070928224 1 -402902351 1070268808 -1 -1094651769 -1081938039 -1 ...
    #>   .. .. .. ..@ sigma     : num 0.297
    #>   .. .. .. ..@ k         : num 2
    #>   .. .. .. ..@ rng.state : int [1:18] 0 1078575104 0 1078575104 -1657977906 1075613906 0 1078558720 277209871 -1068236140 ...
    #>   .. ..- attr(*, "runningTime")= num 0.477
    #>   .. ..- attr(*, "currentNumSamples")= int 1000
    #>   .. ..- attr(*, "currentSampleNum")= int 0
    #>   .. ..- attr(*, "numCuts")= int [1:2] 100 100
    #>   .. ..- attr(*, "cutPoints")=List of 2
    #>   .. .. ..$ : num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
    #>   .. .. ..$ : num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
    #>   ..and 40 methods, of which 26 are  possibly relevant:
    #>   ..  copy#envRefClass, getLatents, getPointer, getTrees, initialize, plotTree,
    #>   ..  predict, printTrees, run, sampleNodeParametersFromPrior,
    #>   ..  sampleTreesFromPrior, setControl, setCutPoints, setData, setModel,
    #>   ..  setOffset, setPredictor, setResponse, setSigma, setState, setTestOffset,
    #>   ..  setTestPredictor, setTestPredictorAndOffset, setWeights,
    #>   ..  show#envRefClass, storeState
    #>  - attr(*, "class")= chr "bart"