I want to plot two vertical lines showing two different means. The plot is split into subplots using a Plotly facet grid.
With the code below, I display each `cut` in a separate subplot. In each subplot, two colors distinguish the values of `val`. For each subplot, I want to show the mean price of each `val` group within that `cut`, but at the moment I only show the mean for the `cut` as a whole.
import math

import numpy as np
import seaborn as sns
import plotly.express as px

diamonds = sns.load_dataset('diamonds')
diamonds['val'] = np.random.randint(1, 3, diamonds.shape[0])

# Mean price for every (cut, val) pair.
grpval = diamonds.groupby(['cut', 'val'])['price'].mean()
print(grpval)

# Pin the facet order explicitly so we know which subplot each cut lands in.
cut_order = list(diamonds['cut'].unique())
wrap = 2

fig = px.histogram(data_frame=diamonds,
                   x='price',
                   facet_col='cut',
                   color='val',
                   facet_col_wrap=wrap,
                   category_orders={'cut': cut_order},
                   )

# Facets are laid out left-to-right, top-to-bottom, but the subplot rows
# accepted by add_vline are numbered from the bottom, so flip the row index.
n_rows = math.ceil(len(cut_order) / wrap)
for i, cut in enumerate(cut_order):
    row = n_rows - i // wrap
    col = i % wrap + 1
    # Mean price of the whole cut (not the last row's price).
    cut_mean = diamonds.loc[diamonds['cut'] == cut, 'price'].mean()
    fig.add_vline(x=cut_mean, line_width=1, line_dash='solid',
                  line_color='red', row=row, col=col)
fig.show()
I tried plotting both means but am only getting one.
Here's how you could add two vertical mean lines (one per `val`, colored to match the histogram) to each subplot:
import math

import numpy as np
import seaborn as sns
import plotly.express as px

diamonds = sns.load_dataset("diamonds")
diamonds["val"] = np.random.randint(1, 3, diamonds.shape[0])

# Mean price for every (cut, val) pair, indexed by (cut, val).
grpval = diamonds.groupby(["cut", "val"])["price"].mean()

# Pin both the facet order and the colour order. By default px orders them by
# first appearance in the data; since `val` is random, val=2 would get the
# first colour about half the time and the line colours would not match.
cut_order = list(diamonds["cut"].unique())
val_order = sorted(diamonds["val"].unique().tolist())
wrap = 2

# Shared colour sequence so histogram bars and mean lines use the same colours.
color_discrete_sequence = px.colors.qualitative.G10

fig = px.histogram(
    data_frame=diamonds,
    x="price",
    facet_col="cut",
    color="val",
    facet_col_wrap=wrap,
    category_orders={"cut": cut_order, "val": val_order},
    color_discrete_sequence=color_discrete_sequence,
)

# Facets are laid out left-to-right, top-to-bottom, but the subplot rows
# accepted by add_vline are numbered from the bottom, so flip the row index.
n_rows = math.ceil(len(cut_order) / wrap)
for i, cut in enumerate(cut_order):
    row = n_rows - i // wrap
    col = i % wrap + 1
    # val_order[k] is drawn with color_discrete_sequence[k] thanks to
    # category_orders above, so zip pairs each mean with its bar colour.
    for val, line_color in zip(val_order, color_discrete_sequence):
        fig.add_vline(
            x=grpval.loc[(cut, val)],
            line_width=1,
            line_dash="solid",
            line_color=line_color,
            row=row,
            col=col,
        )
fig.show()