Plotly: How to show values other than counts in a marginal histogram?


I am trying to create a linked marginal plot above the original plot, with the same x axis but with a different y axis.

I've seen that the plotly.express package offers 4 options for creating a marginal_x plot on a scatter figure, but they are all based on the same columns as x and y.

In my case, I have a date on my x-axis and a rate of something on my y-axis, and I am trying to produce a marginal histogram of the samples on which this rate is based (the samples column in the df).

Here is a simplified version of what I've tried, with no important details left out:

import pandas as pd
import plotly.express as px

df = pd.DataFrame(
    {
        "date": [pd.Timestamp("20200102"), pd.Timestamp("20200103")],
        "rate": [0.88, 0.96],
        "samples": [130, 1200]
    }
)

fig = px.scatter(df, x='date', y='rate', marginal_x='histogram')
fig.show()

The documentation I based this on: https://plotly.com/python/marginal-plots/

My desired result (shown as an example image in the original post):

The difference is that I use an aggregated df, so my count is just 1 per date instead of the number of samples.

Any ideas?

Thanks!


Solution

  • I understand your statement

    [...] and rate of something on my y-axis

    ... to mean that you'd like to display a value on your histogram that is not count.

    marginal_x='histogram' in px.scatter() seems to show counts only by default, meaning that there is no straightforward way to show the values of individual observations. But if you're willing to use fig = make_subplots() in combination with go.Scatter() and go.Bar(), then you can easily build this:

    Complete code:

    import pandas as pd
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    
    fig = make_subplots(rows=2, cols=1,
                        row_heights=[0.2, 0.8],
                        vertical_spacing = 0.02,
                        shared_yaxes=False,
                        shared_xaxes=True)
    
    df = pd.DataFrame(
        {
            "date": [pd.Timestamp("20200102"), pd.Timestamp("20200103")],
            "rate": [0.88, 0.96],
            "samples": [130, 1200]
        }
    )
    
    fig.add_trace(go.Bar(x=df['date'], y=df['rate'], name = 'rate'), row = 1, col = 1)
    
    fig.update_layout(bargap=0, bargroupgap=0)
    
    fig.add_trace(go.Scatter(x=df['date'], y=df['samples'], name = 'samples'), row = 2, col = 1)
    fig.update_traces(marker_color = 'rgba(0,0,250, 0.3)',
                      marker_line_width = 0,
                      selector=dict(type="bar"))
    
    fig.show()