I'm having problems rendering date axes correctly in Shiny for Python with Plotly surface plots. In particular, axes of type date are rendered as floats.
Find here an example in shiny playground.
Note that the exact same code works if I render the figure with fig.show() outside of Shiny for Python (i.e. the x axis renders as dates, not as floats).
I already tried to explicitly set the axis type in the figure layout as follows:
# Explicitly declare the 3D scene's x axis as a date axis and format
# its ticks as four-digit years.
fig.update_layout(
    scene=dict(
        xaxis=dict(
            type="date",
            tickformat='%Y'
        )
    )
)
But I get an even worse result (the figure does not render at all).
Replace the @render_widget decorator with @render.ui (express.render.ui), and replace the return fig statement with return ui.HTML(fig.to_html()) (see express.ui.HTML and plotly.io.to_html).
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from shiny.express import render, ui
def plotSurface(plot_dfs: list, names: list, title: str, **kwargs):
    """Build a plotly figure holding one 3D surface per input dataframe.

    Each dataframe becomes a ``go.Surface`` trace: its index is placed on the
    x axis, its columns on the y axis, and its values on the z axis.

    Args:
        plot_dfs (list): list of dataframes containing the matrix. The index of
            each df must be a ``pd.DatetimeIndex``; columns are maturities.
        names (list): list of names for the surfaces in plot_dfs, one per
            dataframe and in the same order.
        title (str): title of the plot
        **kwargs: accepted for call compatibility; currently unused.

    Raises:
        TypeError: if plot_dfs or names is not a list.
        TypeError: if the index of any dataframe is not a ``pd.DatetimeIndex``.
        ValueError: if plot_dfs and names do not have the same length.

    Returns:
        Figure: plotly figure
    """
    # Check the containers first: iterating a non-list (e.g. a string) below
    # would otherwise produce a misleading "index of type datetime" error.
    if not (isinstance(plot_dfs, list) and isinstance(names, list)):
        raise TypeError(
            f"both plot_dfs and names should be list. Instead got {type(plot_dfs)} and {type(names)}"
        )

    for i, plot_df in enumerate(plot_dfs):
        if not isinstance(plot_df.index, pd.DatetimeIndex):
            raise TypeError(f"plot_df number {i} in plot_dfs should have an index of type datetime but got {type(plot_df.index)}")

    if len(plot_dfs) != len(names):
        raise ValueError(f"plot_dfs and names should have the same length but got {len(plot_dfs)} != {len(names)}")

    fig = go.Figure()

    # stack surfaces. The last one will overwrite the first one when values are equal
    for plot_df, name in zip(plot_dfs, names):
        # meshgrid (default 'xy' indexing) returns grids of shape
        # (n_columns, n_rows), so the values are transposed to line up.
        X, Y = np.meshgrid(plot_df.index, plot_df.columns)
        Z = plot_df.values.T
        fig.add_trace(go.Surface(z=Z, x=X, y=Y, name=name, showscale=False, showlegend=True, opacity=0.9))

    # Update layout for better visualization and custom template
    fig.update_layout(
        title=title,
        title_x=0.5,
        scene=dict(
            xaxis_title='Date',
            yaxis_title='Maturity',
            zaxis_title='Value',
        ),
        margin=dict(l=30, r=30, b=30, t=50),
        # template=esm_theme,
        legend=dict(title="Legend"),
    )
    return fig
@render.ui
def plot_1():
    """Render a demo surface plot in the Shiny UI as embedded HTML.

    Builds a single dataframe filled with the constant 1, indexed by dates
    parsed from "<year>/01/01" strings for years 2020 through 2099, with
    columns "3m", "6m", "9m" followed by "1Y" through "30Y". The dataframe is
    passed to plotSurface and the resulting figure's HTML is wrapped in
    ui.HTML so that Shiny inserts it into the page as-is.
    """
    yearly_dates = pd.to_datetime([f"{y}/01/01" for y in range(2020, 2100)])

    maturities = ["3m", "6m", "9m"]
    maturities.extend(f"{y}Y" for y in range(1, 31))

    constant_surface = pd.DataFrame(data=1, index=yearly_dates, columns=maturities)

    figure = plotSurface([constant_surface], names=["t"], title=" ")
    return ui.HTML(figure.to_html())