r ggplot2 lasso-regression

How to redraw a LASSO regression plot using ggplot?


After fitting a LASSO regression with glmnet, plot(fit) draws the coefficient path plot. I want to redraw this plot with ggplot2 so that I can customize its appearance. The code below redraws it, but the result is inconsistent with the output of plot(fit): the x-axis intervals differ, and the lines are not exactly the same. What is the problem? In addition, how can I add a secondary x-axis to the ggplot version that displays the number of variables, as plot(fit) does? My code is as follows:

library(glmnet)
library(dplyr)
library(ggplot2)
x <- matrix(rnorm(100 * 20), 100, 20)
y <- sample(1:2, 100, replace = TRUE)
fit <- glmnet(x, y, family = "binomial")

# Original graph
plot(fit, xvar = "lambda", label = TRUE)
tidied <- broom::tidy(fit) %>% filter(term != "(Intercept)")

# Redraw with ggplot
ggplot(tidied, aes(lambda, estimate, group = term, color = term)) +
  geom_line() +
  scale_x_log10()

Solution

  • To get started, note that you need the natural log of lambda on the x-axis, not log10.

    tidied <- broom::tidy(fit) %>%
      filter(term != "(Intercept)") %>%
      mutate(lnlambda = log(lambda))
    

    You can then generate a ggplot version like this:

    # Redraw with ggplot
    min_lambda <- min(tidied$lnlambda)
    ggplot(tidied, aes(lnlambda, estimate, group = term, color = term)) +
      geom_line() +
      # Label each path at the left edge with its term name minus the leading character
      geom_text(data = slice_min(tidied, lnlambda, by = term),  # `by` needs dplyr >= 1.1.0
                aes(label = substr(term, 2, nchar(term)), color = term, x = min_lambda, y = estimate),
                nudge_x = -.1, size = 2
      ) +
      # Print the number of nonzero coefficients (df) along the top of the plot
      geom_text(data = slice_min(
        data.frame(df = fit$df, lambda = fit$lambda) %>% filter(df %in% c(20, 19, 17, 13, 6)),
        lambda,
        by = df),
        aes(label = df, x = log(lambda), y = 0.5), inherit.aes = FALSE
      ) +
      scale_y_continuous(breaks = seq(-.4, .4, .2)) +
      theme(legend.position = "none") +
      labs(x = "Log Lambda", y = "Coefficients")
    

    Note: your example is not reproducible, because you did not call set.seed() before sampling. This also means that some of the hard-coded values in my solution will need to be adjusted for other data, such as the df values passed to filter(), the y position (0.5) of the df labels, and the y-axis breaks.

    The seed and data generation below, combined with the code above, produce the plot shown below:

    set.seed(123456)
    x <- matrix(rnorm(100 * 20), 100, 20)
    y <- sample(1:2, 100, replace = TRUE)
    fit <- glmnet(x, y, family = "binomial")
    

    [Image: example_plot]
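
As an alternative to placing the counts with geom_text(), the secondary x-axis asked about in the question can be sketched with dup_axis() from ggplot2. This is a minimal sketch under my own assumptions, not necessarily how plot(fit) computes its top axis: the names top_breaks and top_labels are hypothetical, the tick positions come from pretty(), and each label is the fit$df value at the fitted lambda nearest to that tick. Because nothing is hard-coded, it should not need adjusting when the seed changes.

```r
library(glmnet)
library(dplyr)
library(ggplot2)

set.seed(123456)
x <- matrix(rnorm(100 * 20), 100, 20)
y <- sample(1:2, 100, replace = TRUE)
fit <- glmnet(x, y, family = "binomial")

tidied <- broom::tidy(fit) %>%
  filter(term != "(Intercept)") %>%
  mutate(lnlambda = log(lambda))

# Candidate tick positions on the log(lambda) scale, restricted to the fitted range
log_lambda <- log(fit$lambda)
top_breaks <- pretty(log_lambda)
top_breaks <- top_breaks[top_breaks >= min(log_lambda) & top_breaks <= max(log_lambda)]

# For each tick, take the number of nonzero coefficients at the nearest fitted lambda
# (an assumption of this sketch; plot.glmnet may choose its labels differently)
top_labels <- sapply(top_breaks, function(b) fit$df[which.min(abs(log_lambda - b))])

p <- ggplot(tidied, aes(lnlambda, estimate, group = term, color = term)) +
  geom_line() +
  scale_x_continuous(
    sec.axis = dup_axis(
      breaks = top_breaks,
      labels = as.character(top_labels),
      name = "Number of nonzero coefficients"
    )
  ) +
  theme(legend.position = "none") +
  labs(x = "Log Lambda", y = "Coefficients")
p
```

The top axis shares the log(lambda) scale of the bottom axis; only its breaks, labels, and title are overridden.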