I am trying to create a linked marginal plot above the original plot, with the same x axis but with a different y axis.
I've seen that the plotly.express package offers 4 options for creating a marginal_x plot on a scatter figure, but they are all based on the same columns as x and y.
In my case, I have a date on my x-axis and the rate of something on my y-axis, and I am trying to produce a marginal histogram of the samples on which this rate is based (stored in the samples column of the df).
Here is a simplified version of what I've tried, without losing any important details:
import pandas as pd
import plotly.express as px

df = pd.DataFrame(
    {
        "date": [pd.Timestamp("20200102"), pd.Timestamp("20200103")],
        "rate": [0.88, 0.96],
        "samples": [130, 1200]
    }
)
fig = px.scatter(df, x='date', y='rate', marginal_x='histogram')
fig.show()
The documentation I based this on: https://plotly.com/python/marginal-plots/
My desired result: Example:
The difference is that I use an aggregated df, so my count is just 1 per date instead of the number of samples.
Any ideas?
Thanks!
I understand your statement
[...] and rate of something on my y-axis
... to mean that you'd like to display a value on your histogram that is not a count.
marginal_x='histogram' in px.scatter() appears to show counts only by default, so there is no straightforward way to show the values of individual observations. But if you're willing to use fig = make_subplots() in combination with go.Scatter() and go.Bar(), then you can easily build this:
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Two stacked panels that share the x-axis; the top (marginal) panel is shorter
fig = make_subplots(rows=2, cols=1,
                    row_heights=[0.2, 0.8],
                    vertical_spacing=0.02,
                    shared_yaxes=False,
                    shared_xaxes=True)

df = pd.DataFrame(
    {
        "date": [pd.Timestamp("20200102"), pd.Timestamp("20200103")],
        "rate": [0.88, 0.96],
        "samples": [130, 1200]
    }
)

# Top panel: bars whose heights are the values themselves, not counts
fig.add_trace(go.Bar(x=df['date'], y=df['rate'], name='rate'), row=1, col=1)
fig.update_layout(bargap=0,
                  bargroupgap=0,
                  )

# Bottom panel: scatter trace on the shared x-axis
fig.add_trace(go.Scatter(x=df['date'], y=df['samples'], name='samples'), row=2, col=1)

# Style only the bar trace
fig.update_traces(marker_color='rgba(0,0,250, 0.3)',
                  marker_line_width=0,
                  selector=dict(type="bar"))
fig.show()