Tags: python, plotly, plotly-python

Plotly Express: Remove Trendline from Marginal Distribution Figures


I'm looking for a "clean" way to remove the trendline from the marginal-distribution subplots created with Plotly Express. Since that may be unclear, here is an example:
Generating some fake data:

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

Creating a scatter plot with both marginal and trendline options:

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)

This yields a figure with a trendline in all 3 panels:
[Figure: original fig]

I looked into fig.data and found that the trendlines are its last three traces, and that the last two are the lines shown in the top and right panels. Removing those two traces from fig.data removes the lines from those panels:

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]

[Figure: last 2 objects removed from original fig.data]

This creates a new problem: it also removes the trendline from the legend, which I don't want. So I first need to set showlegend=True on the third-to-last trace (the main panel's trendline):

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]

This finally gives me the figure I wanted:
[Figure: last 2 objects removed from original fig.data and trendline included in legend]

So I do have a solution, but it requires "manhandling" the fig object.
Is there a better, cleaner way of achieving the same final figure?

###############
Full code:

import copy

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px

pio.renderers.default = "browser"

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)
fig.show()

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()

Solution

  • You can use the Figure.update_traces() method, which applies the given properties to every trace that matches the selector argument. Plotly has no dedicated method for removing traces (other than reassigning fig.data, as in the question), but you can hide them with the visible property.

    All three trendline traces share the same name, "Overall Trendline" (set by trendline_scope="overall"), so you can tell them apart by their xaxis (or yaxis) reference: "x" is the x-axis of the main subplot, while "x2" and "x3" are the axes of the right and top subplots respectively.

    For example:

    import numpy as np
    import pandas as pd
    import plotly.express as px

    np.random.seed(42)
    data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
    data["label"] = np.random.choice(list("ABC"), 100)
    data["is_outlier"] = np.random.choice([True, False], 100)

    fig = px.scatter(
        data, x="feature1", y="feature2",
        color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
        log_x=False, marginal_x="box",
        log_y=False, marginal_y="box",
        trendline="ols", trendline_scope="overall", trendline_color_override='black',
        trendline_options=dict(log_x=False, log_y=False),
    )

    # Hide every "Overall Trendline" trace (main panel and both marginal panels)
    fig.update_traces(visible=False, selector=dict(name='Overall Trendline'))
    # Re-show only the main-panel trendline (xaxis 'x') and give it the legend entry
    fig.update_traces(visible=True, showlegend=True, selector=dict(name='Overall Trendline', xaxis='x'))

    fig.show()