How to add a legend for a single series in fable autoplot() and autolayer(), similar to forecast::autolayer()?


The forecast package's autolayer() has a handy series argument that adds a legend entry to a forecast plot.

library(feasts)
library(forecast)
ETS <- forecast(ets(AirPassengers), h = 5)

# `series` names this layer, which adds a legend entry even for a single forecast
forecast::autoplot(AirPassengers) +
  forecast::autolayer(ETS, series = "ETS", PI = FALSE)

I'm now working with fable and would like to replicate this legend.

Here's the test data I'm working with (the dput() output for both objects is at the bottom of this question):

# > test
# A tsibble: 3 x 2 [1D]
  create_date HPTS.East
  <date>          <dbl>
1 2025-05-16         33
2 2025-05-17         50
3 2025-05-18         31

# > fcast
# A fable: 4 x 4 [1D]
# Key:     .model [2]
  .model create_date       HPTS.East .mean
  <chr>  <date>               <dist> <dbl>
1 ets    2025-05-19       N(5.3, 23)  5.35
2 ets    2025-05-20       N(5.1, 24)  5.05
3 stlf   2025-05-19  t(N(1.7, 0.42))  3.38
4 stlf   2025-05-20  t(N(2.2, 0.44))  5.12

When multiple fable forecast models are selected, autolayer() adds a legend automatically.

library(fpp3)
autoplot(test) +
  autolayer(fcast |> filter(.model %in% c("stlf", "ets")), level = NULL) 

But not when only one model is selected...

autoplot(test) +
  autolayer(fcast |> filter(.model %in% c("stlf")), level = NULL) 

How can I get autolayer to display the legend when there is only one fable model selected?

I'm sure it's possible to manipulate the ggplot object manually and add a legend, but I feel like I'm just missing a simple argument such as show_legend = TRUE.
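For reference, the manual route does not need much code. Mapping colour to a constant string inside aes() makes ggplot2 build a one-level discrete colour scale, and therefore a legend, which is the same effect as forecast::autolayer(series = "ETS"). This is a minimal sketch, assuming fpp3 is installed; so that it runs on its own it uses the Hare series from tsibbledata::pelt instead of the question's dput() data, and the names hare, fc and p are illustrative only. Like PI = FALSE, it draws only the point forecast (.mean).

```r
library(fpp3)  # attaches tsibble, tsibbledata, fable, dplyr, ggplot2, ...

# Stand-in data (assumption): annual hare pelts from tsibbledata::pelt
hare <- pelt |> select(Year, Hare)

# A single fable model, which fabletools' autolayer() would draw without a legend
fc <- hare |>
  model(ets = ETS(Hare)) |>
  forecast(h = 5)

p <- ggplot(hare, aes(x = Year, y = Hare)) +
  geom_line() +
  # A constant string mapped to colour creates a one-level discrete scale,
  # so ggplot2 draws a legend with a single "ETS" key
  geom_line(
    aes(x = Year, y = .mean, colour = "ETS"),
    data = as_tibble(fc),
    inherit.aes = FALSE
  ) +
  labs(colour = "Series")

p
```

The label inside aes() plays the role of the series argument, so a second model would just be another layer with a different constant string.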

# dput(test)
test =
structure(list(create_date = structure(c(20224, 20225, 20226), class = "Date"), 
    HPTS.East = c(33, 50, 31)), class = c("tbl_ts", "tbl_df", 
"tbl", "data.frame"), row.names = c(NA, -3L), key = structure(list(
    .rows = structure(list(1:3), ptype = integer(0), class = c("vctrs_list_of", 
    "vctrs_vctr", "list"))), class = c("tbl_df", "tbl", "data.frame"
), row.names = c(NA, -1L)), index = structure("create_date", ordered = TRUE), index2 = "create_date", interval = structure(list(
    year = 0, quarter = 0, month = 0, week = 0, day = 1, hour = 0, 
    minute = 0, second = 0, millisecond = 0, microsecond = 0, 
    nanosecond = 0, unit = 0), .regular = TRUE, class = c("interval", 
"vctrs_rcrd", "vctrs_vctr")))

# dput(fcast)
fcast = 
structure(list(.model = c("ets", "ets", "stlf", "stlf"), create_date = structure(c(20227, 
20228, 20227, 20228), class = "Date"), HPTS.East = structure(list(
    structure(list(mu = 5.34517049658451, sigma = 4.81792422458686), class = c("dist_normal", 
    "dist_default")), structure(list(mu = 5.050629714793, sigma = 4.89510333541553), class = c("dist_normal", 
    "dist_default")), structure(list(dist = structure(list(mu = 1.72153910975751, 
        sigma = 0.645715613426544), class = c("dist_normal", 
    "dist_default")), transform = function (HPTS.East) 
    HPTS.East^2, inverse = function (HPTS.East) 
    sqrt(HPTS.East)), class = c("dist_transformed", "dist_default"
    )), structure(list(dist = structure(list(mu = 2.16255346189791, 
        sigma = 0.665457675943415), class = c("dist_normal", 
    "dist_default")), transform = function (HPTS.East) 
    HPTS.East^2, inverse = function (HPTS.East) 
    sqrt(HPTS.East)), class = c("dist_transformed", "dist_default"
    ))), vars = "HPTS.East", class = c("distribution", "vctrs_vctr", 
"list")), .mean = c(5.34517049658451, 5.050629714793, 3.38064555985075, 
5.11947139402257)), class = c("fbl_ts", "tbl_ts", "tbl_df", "tbl", 
"data.frame"), row.names = c(NA, -4L), key = structure(list(.model = c("ets", 
"stlf"), .rows = structure(list(1:2, 3:4), ptype = integer(0), class = c("vctrs_list_of", 
"vctrs_vctr", "list"))), class = c("tbl_df", "tbl", "data.frame"
), row.names = c(NA, -2L), .drop = TRUE), index = structure("create_date", ordered = TRUE), index2 = "create_date", interval = structure(list(
    year = 0, quarter = 0, month = 0, week = 0, day = 1, hour = 0, 
    minute = 0, second = 0, millisecond = 0, microsecond = 0, 
    nanosecond = 0, unit = 0), .regular = TRUE, class = c("interval", 
"vctrs_rcrd", "vctrs_vctr")), response = "HPTS.East", dist = "HPTS.East", model_cn = ".model")

I looked through the autoplot() examples in https://otexts.com/fpp3/ets-forecasting.html, but couldn't find a way to add a legend the way the fpp2 textbook (which uses the forecast package) does.


Solution

  • The autoplot() and autolayer() methods for fable objects are designed to produce colours only when there are multiple .model values, and there is no option to force them.

    These plot helpers are designed as quick analytical tools. To produce publication-ready graphics, I recommend constructing the plot with the ggplot2 grammar and ggdist.

    # Import `dput()` data as above, removed for brevity
    library(fpp3)
    library(ggdist)
    
    fcast |> 
      ggplot(aes(ydist = HPTS.East, x = create_date)) + 
      stat_lineribbon(
        aes(colour = .model, fill = .model, fill_ramp = after_stat(.width)), 
        alpha = 0.3, .width = c(0.8, 0.95)
      ) +
      geom_line(aes(y = HPTS.East), data = test)
    

    forecast plot of multiple models with grammar

    Filtering the forecasts to a single model still produces colour, fill and a legend:

    fcast |> 
      filter(.model == "ets") |>
      ggplot(aes(ydist = HPTS.East, x = create_date)) + 
      stat_lineribbon(
        aes(colour = .model, fill = .model, fill_ramp = after_stat(.width)), 
        alpha = 0.3, .width = c(0.8, 0.95)
      ) +
      geom_line(aes(y = HPTS.East), data = test)
    

    forecast plot of single model with grammar producing colour, fill and legend

    Created on 2025-05-28 with reprex v2.1.1
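To mirror forecast::autolayer(series = "ETS") more closely with the grammar approach, the .model value can be recoded to whatever label should appear in the legend before plotting. Giving the colour and fill legends the same title lets ggplot2 merge them into a single key. This is a sketch under the same assumptions as before: fpp3 and ggdist are installed, and tsibbledata::pelt (Hare) stands in for the question's data; hare, fc and p are illustrative names.

```r
library(fpp3)
library(ggdist)  # stat_lineribbon() and the ydist / fill_ramp aesthetics

# Stand-in data (assumption): annual hare pelts from tsibbledata::pelt
hare <- pelt |> select(Year, Hare)

fc <- hare |>
  model(ets = ETS(Hare)) |>
  forecast(h = 5) |>
  as_tibble() |>
  # Relabel the model with the name wanted in the legend,
  # analogous to series = "ETS" in forecast::autolayer()
  mutate(.model = "ETS")

p <- ggplot(fc, aes(x = Year, ydist = Hare)) +
  stat_lineribbon(
    aes(colour = .model, fill = .model, fill_ramp = after_stat(.width)),
    alpha = 0.3, .width = c(0.8, 0.95)
  ) +
  # Observed history; inherit.aes = FALSE so the ydist mapping is not carried over
  geom_line(aes(x = Year, y = Hare), data = hare, inherit.aes = FALSE) +
  # Same title for colour and fill lets ggplot2 merge them into one legend
  labs(colour = "Series", fill = "Series")

p
```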