r decision-tree predict caret

Error predicting with REEMtree model: Number of observations in newdata does not match group identifiers


I am using the REEMtree package in R to build a tree with random effects, but when I attempt to make predictions on the test data, I encounter the following error:

Error in predict.REEMtree(my_REEMtree, newdata = testing_data) : number of observations in newdata does not match the length of the group identifiers

Here's my reproducible code:

library(caret)
library(REEMtree)
library(rpart)

# Set the seed before simulating so the generated data are reproducible too
set.seed(123)

# Generate synthetic data with pupils in classes
n_classes <- 30
n_pupils <- 30

ds <- data.frame(
  x1 = rnorm(n_classes * n_pupils, 0, 1),
  z1 = rep(rnorm(n_classes, 0, 1), each = n_pupils)
)

ds$y <- 1 + 2 * ds$x1 + 3 * ds$z1 + rnorm(n_classes * n_pupils, 0, 1)
# Create the class id (one factor level per class)
ds$class_id <- as.factor(rep(1:n_classes, each = n_pupils))

# Split the data
set.seed(123) # For reproducibility
trainingRows <- createDataPartition(ds$y, p = .80, list = FALSE)
training_data <- ds[trainingRows,]
testing_data <- ds[-trainingRows,]

# Fit the model
my_REEMtree <- REEMtree(y ~ x1 + z1, data=training_data, random=~1|class_id, tree.control=rpart.control(cp=0.001))

# Predict with the model (this call raises the error shown above)
predictions <- predict(my_REEMtree, newdata=testing_data)

# Checks performed while trying to diagnose the error
print(levels(ds$class_id))
print(levels(training_data$class_id))
print(levels(testing_data$class_id))
sum(is.na(training_data$class_id))
sum(is.na(testing_data$class_id))
summary(testing_data$class_id)
all(levels(training_data$class_id) == levels(testing_data$class_id)) # returns TRUE

I have validated that there are no NA values in the class_id column, and the factor levels are the same across the training and testing sets.

I suspect the error is related to the internal expectations of the predict() function, but I have been unable to pinpoint the issue. Any insights into why this error is occurring and how to resolve it would be appreciated.


Solution

  • Got it.

    For prediction with a REEMtree model to work, you have to pass the group identifiers of the new data to predict() through the id argument:

    predict(my_REEMtree, newdata=testing_data, id = testing_data$class_id)
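
    The error message says that the number of observations in newdata must match the length of the group identifiers, so a small guard can catch the problem before predict() is called. The following is only a sketch in base R: check_group_ids is a hypothetical helper name, not part of REEMtree, and the toy data frame merely stands in for testing_data.

```r
# Hypothetical helper (NOT part of REEMtree): checks that a vector of group
# identifiers has exactly one non-missing entry per row of the new data.
check_group_ids <- function(newdata, id) {
  if (is.null(id)) {
    stop("no group identifiers supplied; pass them via the id argument")
  }
  if (length(id) != nrow(newdata)) {
    stop("length(id) is ", length(id),
         " but newdata has ", nrow(newdata), " rows")
  }
  if (anyNA(id)) {
    stop("group identifiers contain NA values")
  }
  invisible(TRUE)
}

# Toy data standing in for testing_data (assumed column names)
toy <- data.frame(x1 = c(0.1, -0.4, 1.2), class_id = factor(c(1, 1, 2)))

# Passes: one identifier per row
ok <- check_group_ids(toy, toy$class_id)

# Fails: no identifiers at all, which mirrors calling predict() without id
missing_id <- tryCatch(
  check_group_ids(toy, NULL),
  error = function(e) conditionMessage(e)
)
```

    With the objects from the question, the guard would be called as check_group_ids(testing_data, testing_data$class_id) just before the predict() call above.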