Tags: r, ggplot2, gam, mgcv, gratia

Cannot force legend to bottom in gratia::compare_smooths


I am using the wonderful gratia package to visualize multiple GAMs fitted with the mgcv package, via the compare_smooths function. I can fit and draw the models without problems, but I cannot find a way to force the legend to the bottom of the figure.

My code is the following:

library(tidyverse)
library(gratia)
library(mgcv)
library(ggpubr)

gam1 <- gam(mpg ~ am + s(hp),data = mtcars |> filter(cyl >= 6))
gam2 <- gam(mpg ~ am + s(hp),data = mtcars |> filter(cyl < 6))

gam3 <- gam(mpg ~ am + s(hp),data = mtcars |> filter(qsec >= 18))
gam4 <- gam(mpg ~ am + s(hp),data = mtcars |> filter(qsec < 18))


plot1 <- draw(compare_smooths(gam1,gam2)) + theme(legend.position = "bottom")
plot2 <- draw(compare_smooths(gam3,gam4)) + theme(legend.position = "bottom")
              
gamplot <- ggarrange(plotlist = list(plot1,plot2),ncol = 2) 
gamplot

This creates the following result:

[Image: plot1]

The legend is not at the bottom, even though I specified it. I also tried the legend argument of ggarrange, but that did not work either:

gamplot2 <- ggarrange(plotlist = list(plot1,plot2),ncol = 2,legend = "bottom") 
gamplot2

[Image: plot2]

Even when each plot is printed by itself, the legend does not move:

plot1

[Image: plot3]

I would like to put both legends at the bottom. How can I do this? Is it possible within gratia using compare_smooths, or is it not doable?


Solution

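  • The plot object returned by gratia::draw() is a patchwork, and its guides are collected (guides = "collect"), as the str() output below shows.

    With +, patchwork adds the theme only to the last (active) plot, which here does not move the collected legend. To apply a theme modification to every single ggplot object in the patchwork, use & instead of +.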

    library(gratia)
    library(mgcv)
    #> Loading required package: nlme
    #> This is mgcv 1.9-1. For overview type 'help("mgcv-package")'.
    library(ggpubr)
    #> Loading required package: ggplot2
    
    gam1 <- gam(mpg ~ am + s(hp), data = mtcars |> subset(cyl >= 6))
    gam2 <- gam(mpg ~ am + s(hp), data = mtcars |> subset(cyl < 6))
    
    gam3 <- gam(mpg ~ am + s(hp), data = mtcars |> subset(qsec >= 18))
    gam4 <- gam(mpg ~ am + s(hp), data = mtcars |> subset(qsec < 18))
    
    class(draw(compare_smooths(gam1, gam2)))
    #> [1] "patchwork" "gg"        "ggplot"
    
    str(draw(compare_smooths(gam1, gam2)))
    #> A patchwork composed of 1 patches
    #> - Autotagging is turned off
    #> - Guides are collected
    #> 
    #> Layout:
    #> 1 patch areas, spanning 1 columns and 1 rows
    #> 
    #>     t l b r
    #> 1:  1 1 1 1
    
    # & (not +) applies the theme to every plot in the patchwork
    plot1 <- draw(compare_smooths(gam1, gam2)) & theme(legend.position = "bottom")
    plot2 <- draw(compare_smooths(gam3, gam4)) & theme(legend.position = "bottom")
    
    gamplot <- ggarrange(plotlist = list(plot1, plot2), ncol = 2)
    gamplot
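The same + versus & behaviour can be reproduced with patchwork alone. The sketch below is a minimal illustration, assuming the patchwork package is installed (gratia relies on it for draw()); two ordinary ggplots with the same colour legend stand in for the compare_smooths() output, and the names p1, p2, pw and pw_bottom are illustrative only. It also shows patchwork's | operator with plot_layout(guides = "collect") as a possible replacement for ggpubr::ggarrange, in which case identical legends should be merged into one.

```r
library(ggplot2)
library(patchwork)

# Two ordinary ggplots that carry the same colour legend
# (illustrative stand-ins for draw(compare_smooths(...)) output)
p1 <- ggplot(mtcars, aes(hp, mpg, colour = factor(am))) +
  geom_point() +
  labs(colour = "am")
p2 <- ggplot(mtcars, aes(qsec, mpg, colour = factor(am))) +
  geom_point() +
  labs(colour = "am")

# Place the plots side by side and collect their legends
pw <- (p1 | p2) + plot_layout(guides = "collect")

# `+ theme(...)` would modify only the last plot (p2);
# `& theme(...)` adds the theme to every plot in the patchwork
pw_bottom <- pw & theme(legend.position = "bottom")

print(pw_bottom)
```

Keeping both panels in a single patchwork means one & theme(...) call governs the legend position for the whole figure, instead of relying on ggarrange to rearrange legends it did not create.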