
Why does R's predict segmented package not include effect of other covariates?


I have a model fitted with the segmented package. It consists of one segmented variable and some categorical variables. The problem is that when I use the model for prediction, it only uses the segmented variable and ignores the categorical ones.

The problem can be recreated with this code:

library(segmented)
n=10
x=rep(seq(-3,3,l=n), 2)
z=c(rep(1, 10), rep(0, 10))
set.seed(1515)

y <- (x<0)*x/2 + 1 + 0.5*z + rnorm(x,sd=0.15)
segm <- segmented(lm(y ~ x + as.factor(z)), ~ x, psi=0.5)

newdf <- data.frame(expand.grid(x=seq(-3, 3, 0.5), z=c("1", "0")))
newdf$p1 <- predict(segm, newdata = newdf)
plot(newdf$x, newdf$p1)

The predict function returns exactly the same value regardless of the value of z.

I would have expected the effect of z to be included. I also tried extracting the components of the prediction with type="terms", as described in the package documentation, but that fails too: Error in match.arg(type) : 'arg' should be one of “link”, “response”


Solution

  • In the version I have (I know some commenters couldn't replicate the problem), the problem appears to arise when the function identifies the variables and corresponding coefficients to include in the prediction. You can follow along using debugonce(predict.segmented) until lines 150-151:

      nomiOK <- intersect(names(estcoef.noV), colnames(X.noV))
    

    The column names of X.noV are the variable names in newdata, while the names of estcoef.noV are the coefficient names. In the data the variable is z, but the corresponding coefficient is named z1, so z1 is never matched and the effect of z is left out of the prediction.
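    To see the mismatch in isolation, here is a minimal base-R sketch that mimics the name matching with an ordinary lm() fit (the data frame d and its values are made up for illustration). With default treatment contrasts, the coefficient for a two-level factor z is named z1, while the data column is named z, so intersecting the two sets of names keeps only x.

    ```r
    # Hypothetical toy data: numeric x and a two-level factor z (levels "0", "1")
    d <- data.frame(
      x = c(-2, -1, 0, 1, 2, 3),
      z = factor(c(0, 1, 0, 1, 0, 1)),
      y = c(0.1, 0.9, 1.2, 1.4, 1.1, 1.8)
    )
    fit <- lm(y ~ x + z, data = d)

    coef_names <- names(coef(fit))  # "(Intercept)" "x" "z1"
    data_names <- names(d)          # "x" "z" "y"

    # Analogous to the nomiOK step: only "x" survives, so z1 is silently dropped
    matched <- intersect(coef_names, data_names)
    print(matched)
    ```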

    Here's a way you could make it work. First, we'll build your data:

    library(segmented)
    n=10
    x=rep(seq(-3,3,l=n), 2)
    z=as.factor(c(rep(1, 10), rep(0, 10)))
    set.seed(1515)
    
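    # Note: z has levels "0" and "1", so (z == 2) is always FALSE and no z effect
    # is simulated here; (z == "1") would reproduce the 0.5 shift from the question.
    y <- (x<0)*x/2 + 1 + 0.5*(z == 2) + rnorm(x,sd=0.15)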
    dat <- data.frame(x=x, z=z, y=y)
    

    Now, use the same segmented() call, but save the linear model that is used as input to segmented() as an object - you can do this inline.

    segm <- segmented(lmod <- lm(y ~ x + z, data=dat), ~ x, psi=0.5)
    

    Generate your new data as before, making z the appropriate factor.

    newdf <- data.frame(expand.grid(y = 0, 
                                    x=seq(-3, 3, 0.5), 
                                    z=factor(1:2, labels=c(0,1))))
    

    Here is where the difference lies: you also need to build the model matrix of the input linear model, using the new data frame as the data. Note that the new data frame must contain values for the dependent variable, but they can be anything because they are not used in constructing the model matrix. That's why I used y=0 above.

    X <- model.matrix(lmod, data=newdf)
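    If you would rather not add a placeholder response, one possible alternative (a sketch, not something taken from the segmented documentation) is to drop the response from the model's terms with delete.response() and pass the fitted factor levels through the xlev argument of model.matrix(). It is shown here on a plain lm() fit with made-up data standing in for lmod.

    ```r
    # Made-up data standing in for dat; lmod is an ordinary lm() fit
    dat <- data.frame(
      x = c(-2, -1, 0, 1, 2, 3),
      z = factor(c(0, 1, 0, 1, 0, 1)),
      y = c(0.1, 0.9, 1.2, 1.4, 1.1, 1.8)
    )
    lmod <- lm(y ~ x + z, data = dat)

    # New data with no y column at all
    newdf <- expand.grid(x = seq(-3, 3, 0.5), z = factor(c("0", "1")))

    # Terms without the response, so model.matrix() does not look for y;
    # xlev keeps the factor coding identical to the fitted model
    tt <- delete.response(terms(lmod))
    X <- model.matrix(tt, data = newdf, xlev = lmod$xlevels)
    colnames(X)
    ```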
    

    Now attach to newdf all the columns from X that aren't already there. You need to do this because the function checks that all variables in the formula are present (e.g., z has to be in the data), but later, when the calculations are done, z1 also has to be there.

    newdf <- cbind(newdf, X[, setdiff(colnames(X), names(newdf)), drop = FALSE])
    

    Now the predictions differ between the two levels of z.

    newdf$p1 <- predict(segm, newdata = newdf)
    newdf
    #>    y    x z (Intercept) z1          p1
    #> 1  0 -3.0 0           1  0 -0.64454838
    #> 2  0 -2.5 0           1  0 -0.35157664
    #> 3  0 -2.0 0           1  0 -0.05860491
    #> 4  0 -1.5 0           1  0  0.23436683
    #> 5  0 -1.0 0           1  0  0.52733857
    #> 6  0 -0.5 0           1  0  0.82031030
    #> 7  0  0.0 0           1  0  0.92298030
    #> 8  0  0.5 0           1  0  0.93049974
    #> 9  0  1.0 0           1  0  0.93801918
    #> 10 0  1.5 0           1  0  0.94553862
    #> 11 0  2.0 0           1  0  0.95305805
    #> 12 0  2.5 0           1  0  0.96057749
    #> 13 0  3.0 0           1  0  0.96809693
    #> 14 0 -3.0 1           1  1 -0.61113654
    #> 15 0 -2.5 1           1  1 -0.31816481
    #> 16 0 -2.0 1           1  1 -0.02519307
    #> 17 0 -1.5 1           1  1  0.26777866
    #> 18 0 -1.0 1           1  1  0.56075040
    #> 19 0 -0.5 1           1  1  0.85372214
    #> 20 0  0.0 1           1  1  0.95639214
    #> 21 0  0.5 1           1  1  0.96391157
    #> 22 0  1.0 1           1  1  0.97143101
    #> 23 0  1.5 1           1  1  0.97895045
    #> 24 0  2.0 1           1  1  0.98646989
    #> 25 0  2.5 1           1  1  0.99398933
    #> 26 0  3.0 1           1  1  1.00150877
    

    Created on 2024-01-18 with reprex v2.0.2
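    The workaround (build the model matrix from the new data, then append the columns it is missing) can be wrapped in a small helper. add_design_columns() is a hypothetical name, not part of segmented. This sketch follows the same steps as above, so newdata still needs a placeholder response column, and it uses drop = FALSE so that a single missing column keeps its name. It is demonstrated on a plain lm() fit with made-up data.

    ```r
    # Hypothetical helper (not part of segmented): append the design-matrix
    # columns of `lmod` that `newdata` does not already contain.
    # newdata must include a placeholder response column (e.g. y = 0).
    add_design_columns <- function(newdata, lmod) {
      X <- model.matrix(lmod, data = newdata)
      missing_cols <- setdiff(colnames(X), names(newdata))
      # drop = FALSE keeps a one-column result as a named matrix
      cbind(newdata, X[, missing_cols, drop = FALSE])
    }

    # Made-up data standing in for dat
    dat <- data.frame(
      x = c(-2, -1, 0, 1, 2, 3),
      z = factor(c(0, 1, 0, 1, 0, 1)),
      y = c(0.1, 0.9, 1.2, 1.4, 1.1, 1.8)
    )
    lmod <- lm(y ~ x + z, data = dat)

    newdf <- expand.grid(y = 0, x = seq(-3, 3, 0.5), z = factor(c("0", "1")))
    newdf <- add_design_columns(newdf, lmod)
    names(newdf)
    ```

    With a segmented fit built from lmod, the augmented newdf would then be passed to predict(segm, newdata = newdf) as in the steps above.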