I have a pivoted DataFrame that I want to plot as a heatmap in Plotly. It includes a total column, which makes the color scale hard to use because the totals are far larger than the other cells.
import numpy as np
import pandas as pd
import plotly.express as px
city_origin = ['London', 'Tokio', 'Seoul', 'Paris', 'Tashkent', 'Washington', 'Moscow']
city_current = ['London', 'Madrid', 'Tashkent', 'Seoul', 'Paris', 'Toronto', 'Washington', 'Istanbul', 'Hanoi', 'Manilla', 'Delhi', 'Busan', 'Moscow']
migrant_origin = np.random.choice(city_origin,size = 1000)
migrant_current = np.random.choice(city_current, size = 1000)
migrant_salary = np.random.randint(1300, 6900, size = 1000)
df = pd.DataFrame({'migrant_origin':migrant_origin, 'migrant_current':migrant_current, 'migrant_salary':migrant_salary})
new_df = pd.pivot_table(df, index = 'migrant_origin', columns = 'migrant_current', values ='migrant_salary', aggfunc = 'sum', fill_value = 0)
new_df['total'] = new_df.sum(axis = 1)
cols = ['total'] + [col for col in new_df.columns if col != 'total']
new_df = new_df[cols]
px.imshow(new_df,text_auto=True)
I want to show the totals without letting them affect the colors of the heatmap. Is there a way to exclude a column from the coloring?
You can make a copy of the DataFrame to hold the values and set every value in the total column to None so that no color shows up there. Then make another copy of the DataFrame with the original values as text, and display this text.
The only tricky part is that you have to use go.Heatmap instead of px.imshow and pass the value and text DataFrames as numpy arrays. You also have to reverse the order of the rows in the DataFrame, since go.Heatmap draws the first row of your array at the bottom of the heatmap.
Since the background of the plotly base figure will be visible (through the transparent total column), I removed the gridlines as well. You can adjust the background color or theme if you want the total column to have a different color.
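The data preparation described above can be sketched on a toy table before looking at the full example. The table and its numbers are made up for illustration, and NaN stands in for None here so the values array stays numeric; that is a substitution on my part, not what the full example does.

```python
import numpy as np
import pandas as pd

# Toy pivot table (made-up numbers) with a leading 'total' column.
toy = pd.DataFrame(
    {"total": [30, 70], "A": [10, 30], "B": [20, 40]},
    index=["r1", "r2"],
)

# Reverse the rows so the first row ends up at the top of the heatmap.
flipped = toy.iloc[::-1]

# Values used for coloring: blank out the total column.
values = flipped.astype(float)   # astype returns a new frame
values["total"] = np.nan         # assumption: NaN used in place of None

# Text shown in the cells: original numbers, totals included.
text = flipped.astype(str)

print(values.to_numpy())
print(text.to_numpy())
```

The values array would be passed as z and the text array as text, so the totals remain readable while contributing nothing to the color range.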
Fully reproducible example:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
city_origin = ['London', 'Tokio', 'Seoul', 'Paris', 'Tashkent', 'Washington', 'Moscow']
city_current = ['London', 'Madrid', 'Tashkent', 'Seoul', 'Paris', 'Toronto', 'Washington', 'Istanbul', 'Hanoi', 'Manilla', 'Delhi', 'Busan', 'Moscow']
np.random.seed(42)
migrant_origin = np.random.choice(city_origin,size = 1000)
migrant_current = np.random.choice(city_current, size = 1000)
migrant_salary = np.random.randint(1300, 6900, size = 1000)
df = pd.DataFrame({'migrant_origin':migrant_origin, 'migrant_current':migrant_current, 'migrant_salary':migrant_salary})
new_df = pd.pivot_table(df, index = 'migrant_origin', columns = 'migrant_current', values ='migrant_salary', aggfunc = 'sum', fill_value = 0)
new_df['total'] = new_df.sum(axis = 1)
non_total_cols = [col for col in new_df.columns if col != 'total']
cols = ['total'] + non_total_cols
new_df = new_df[cols]
new_df_values = new_df.iloc[::-1].copy()
new_df_values['total'] = None
new_df_text = new_df.iloc[::-1].astype(str).copy()
## separate out text and values into numpy arrays
values_array = new_df_values.to_numpy()
text_array = new_df_text.to_numpy()
y_array = new_df_text.index.values
x_array = new_df_text.columns.values
fig = go.Figure()
fig.add_trace(go.Heatmap(
    z=values_array,
    y=y_array,
    x=x_array,
    text=text_array,
    texttemplate="%{text:.2s}",
))
fig.update_layout(
    xaxis=dict(showgrid=False),
    yaxis=dict(showgrid=False),
)
fig.show()