r

Creating one legend for multiple plots


I have R code that makes multiple 3D surface plots with plotly.

First, I defined the following functions to generate these plots:

library(plotly)
library(dplyr)
library(htmltools)

# Evaluate B = (w4 / w2) * w3 * p1 * p2 * p3 on every (p1, p2) pair.
# Returns a length(p1) x length(p2) matrix: row i is p1[i], column j is p2[j].
calculate_b_values <- function(p1, p2, p3, weights) {
    # Constant part of B, multiplied in the same left-to-right order as the
    # closed-form expression so every cell is computed identically.
    scale_factor <- (weights$w4 / weights$w2) * weights$w3
    cell_value <- function(x, y) scale_factor * x * y * p3
    # unname(): outer() would copy names of p1/p2 into dimnames; keep none.
    unname(outer(p1, p2, cell_value))
}

# Build one 3D plotly surface of B over the (p1, p2) grid.
#
# p1, p2:   grid values shown on the x and y axes respectively.
# z_matrix: B values with rows along p1 and columns along p2
#           (the layout returned by calculate_b_values()).
# p3_value: fixed p3 value, shown in the plot title.
# Returns a plotly htmlwidget.
create_surface_plot <- function(p1, p2, z_matrix, p3_value) {
    plot_ly() %>% 
        add_surface(
            # plotly maps z rows to y and z columns to x, so transpose to put
            # p1 on the x axis and p2 on the y axis, matching the axis titles.
            z = t(z_matrix),
            x = p1,
            y = p2,
            colorscale = "Viridis",
            opacity = 0.9
        ) %>% 
        layout(
            scene = list(
                xaxis = list(title = "p1"),
                yaxis = list(title = "p2"),
                zaxis = list(title = "B"),
                aspectratio = list(x = 1, y = 1, z = 0.7)
            ),
            title = paste("Surface Plot with p3 =", p3_value)
        )
}

# Create one surface plot per fixed p3 value.
# p_range is used as both the p1 and the p2 grid. Returns a list of plots
# named "p3_<value>", in the same order as p3_values.
generate_surface_plots <- function(p_range = seq(0.1, 1, length.out = 20),
                                   p3_values = c(0.1, 0.4, 0.7, 1.0),
                                   weights = list(w1 = 10, w2 = 5, w3 = 3, w4 = 8)) {
    build_one <- function(idx) {
        p3 <- p3_values[idx]
        z <- calculate_b_values(p_range, p_range, p3, weights)
        create_surface_plot(p_range, p_range, z, p3)
    }
    setNames(lapply(seq_along(p3_values), build_one),
             paste0("p3_", p3_values))
}

Then I called this function to create the plots:

# Inputs: the B weights, the grid shared by p1 and p2, and one p3 per plot.
weights <- list(w1 = 10, w2 = 5, w3 = 3, w4 = 8)
p_vals <- seq(0.1, 1, length.out = 20)
p3_values <- c(0.1, 0.4, 0.7, 1.0)

plots <- generate_surface_plots(
    p_range = p_vals,
    p3_values = p3_values,
    weights = weights
)

Finally, I plotted all of them on the same page (I learned how to do this from the question "Putting multiple plots on the same page in R?"):

# Lay out the first four plots in a 2-column CSS grid of fixed-size cells.
local({
    cell_style <- "width: 300px; height: 250px;"
    cells <- lapply(seq_len(4), function(i) div(style = cell_style, plots[[i]]))
    browsable(
        do.call(div, c(
            list(style = "display: grid; grid-template-columns: 1fr 1fr; gap: 10px;"),
            cells
        ))
    )
})

[Image: the four surface plots, each with its own color scale]

In the plots above, each plot has its own color scale because the range of B differs from plot to plot. Is it possible to use a single, shared color scale for all plots?


Solution

  • Pass the same cmin (the lowest B across all plots) and cmax (the highest B across all plots) to add_surface() in every plot to get a consistent color scale across all plots:

    add_surface(
        z = z_matrix,
        x = p1,
        y = p2,
        colorscale = "Viridis",
        # Use identical limits in every plot. They must cover the full data
        # range: with the question's weights B = 4.8 * p1 * p2 * p3, which
        # reaches 4.8 at p1 = p2 = p3 = 1, so cmax = 4 would clip that surface.
        cmin = 0,
        cmax = 4.8
    )
    

    A full solution that computes the shared range from the data could look as follows (with thanks to Tim G):

    library(plotly)
    library(dplyr)
    library(htmltools)
    
    # Matrix of B = (w4 / w2) * w3 * p1 * p2 * p3, with p1 along the rows and
    # p2 along the columns.
    calculate_b_values <- function(p1, p2, p3, weights) {
      # Value of B at one (p1, p2) pair, multiplied in the same order as the
      # closed-form expression.
      b_at <- function(a, b) (weights$w4 / weights$w2) * weights$w3 * a * b * p3
      # outer() would turn names on p1/p2 into dimnames; the result has none.
      unname(outer(p1, p2, b_at))
    }
    
    # Return c(min, max) of B over every surface that will be plotted, so all
    # plots can be coloured on one shared range.
    #
    # p_range:   grid values used for both p1 and p2.
    # p3_values: fixed p3 values, one per plot; must be non-empty.
    # weights:   list of weights passed through to calculate_b_values().
    get_min_max <- function(p_range, p3_values, weights) {
      # An empty p3_values would otherwise yield c(Inf, -Inf) with warnings.
      stopifnot(length(p3_values) > 0)
      # Collect every matrix first and take the range once, instead of
      # growing a vector with c() inside a loop.
      z_list <- lapply(p3_values, function(p3) {
        calculate_b_values(p_range, p_range, p3, weights)
      })
      range(unlist(z_list))
    }
    
    # Build one 3D plotly surface of B over the (p1, p2) grid, coloured on a
    # fixed range so several plots can share one colour scale.
    #
    # p1, p2:     grid values shown on the x and y axes respectively.
    # z_matrix:   B values with rows along p1 and columns along p2
    #             (the layout returned by calculate_b_values()).
    # p3_value:   fixed p3 value, shown in the plot title.
    # min_max:    numeric c(min, max) used as cmin / cmax of the colour scale.
    # show_scale: draw this plot's colour bar? Set FALSE on all but one plot
    #             to show a single legend for a whole set of plots.
    # Returns a plotly htmlwidget.
    create_surface_plot <- function(p1, p2, z_matrix, p3_value, min_max,
                                    show_scale = TRUE) {
      stopifnot(is.numeric(min_max), length(min_max) == 2)
      plot_ly() %>% 
        add_surface(
          # plotly maps z rows to y and z columns to x, so transpose to put
          # p1 on the x axis and p2 on the y axis, matching the axis titles.
          z = t(z_matrix),
          x = p1,
          y = p2,
          colorscale = "Viridis",
          opacity = 0.9,
          cmin = min_max[1],
          cmax = min_max[2],
          showscale = show_scale
        ) %>% 
        layout(
          scene = list(
            xaxis = list(title = "p1"),
            yaxis = list(title = "p2"),
            zaxis = list(title = "B"),
            aspectratio = list(x = 1, y = 1, z = 0.7)
          ),
          title = paste("Surface Plot with p3 =", p3_value)
        )
    }
    
    # Build one surface plot per p3 value, all coloured on one shared range.
    #
    # p_range:   grid values used for both p1 and p2.
    # p3_values: fixed p3 value for each plot; also used to name the result.
    # weights:   list of weights passed through to calculate_b_values().
    # Returns a list of plots named "p3_<value>", in the order of p3_values.
    # Defaults match the question's original version, so calling it with no
    # arguments still works.
    generate_surface_plots <- function(p_range = seq(0.1, 1, length.out = 20),
                                       p3_values = c(0.1, 0.4, 0.7, 1.0),
                                       weights = list(w1 = 10, w2 = 5, w3 = 3, w4 = 8)) {
      # Overall min and max of B across every plot, used as each plot's
      # colour limits so the colour bars agree.
      min_max <- get_min_max(p_range, p3_values, weights)

      plot_list <- lapply(seq_along(p3_values), function(i) {
        p3_fixed <- p3_values[i]
        z_matrix <- calculate_b_values(p_range, p_range, p3_fixed, weights)
        create_surface_plot(p_range, p_range, z_matrix, p3_fixed, min_max)
      })
      names(plot_list) <- paste0("p3_", p3_values)
      plot_list
    }
    
    # Inputs. Note: calculate_b_values() reads only w2, w3 and w4.
    weights <- list(w1 = 10, w2 = 5, w3 = 3, w4 = 8)
    p_vals <- seq(from = 0.1, to = 1, length.out = 20)  # shared p1 / p2 grid
    p3_values <- c(0.1, 0.4, 0.7, 1.0)                   # one plot per value

    plots <- generate_surface_plots(p_range = p_vals, p3_values = p3_values,
                                    weights = weights)
    
    # Show the first four plots side by side, two per row, in fixed-size cells.
    local({
      grid_style <- "display: grid; grid-template-columns: 1fr 1fr; gap: 10px;"
      panel <- function(i) div(style = "width: 300px; height: 250px;", plots[[i]])
      browsable(do.call(div, c(style = grid_style, lapply(seq_len(4), panel))))
    })
    

    out