
How can I exclude a column from a heatmap?


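I have a pivoted DataFrame and want to plot it as a heatmap in Plotly. It has a total column whose values are much larger than the other cells. The totals stretch the color scale, so the remaining cells are hard to tell apart.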

import numpy as np
import pandas as pd
import plotly.express as px


city_origin = ['London', 'Tokio', 'Seoul', 'Paris', 'Tashkent', 'Washington', 'Moscow']
city_current = ['London', 'Madrid', 'Tashkent', 'Seoul', 'Paris', 'Toronto', 'Washington', 'Istanbul', 'Hanoi', 'Manilla', 'Delhi', 'Busan', 'Moscow']

migrant_origin = np.random.choice(city_origin, size=1000)
migrant_current = np.random.choice(city_current, size=1000)
migrant_salary = np.random.randint(1300, 6900, size=1000)
df = pd.DataFrame({'migrant_origin': migrant_origin, 'migrant_current': migrant_current, 'migrant_salary': migrant_salary})

new_df = pd.pivot_table(df, index='migrant_origin', columns='migrant_current', values='migrant_salary', aggfunc='sum', fill_value=0)

new_df['total'] = new_df.sum(axis=1)
cols = ['total'] + [col for col in new_df.columns if col != 'total']
new_df = new_df[cols]

fig = px.imshow(new_df, text_auto=True)
fig.show()


I want to show the totals without them affecting the colors of the heatmap. Is there a way to exclude a column from the coloring?


Solution

  • One option is to make a copy of the DataFrame for the values and set every value in the total column to None. Those cells then get no color and are ignored when the color scale is computed. Next, make a second copy with the original values converted to strings, and display that as the cell text.

    The slightly tricky part is that this uses go.Heatmap instead of px.imshow, which makes it easy to pass the values and the text as separate NumPy arrays. The rows also have to be reversed. go.Heatmap draws the first row of z at the bottom of the y-axis, whereas px.imshow draws it at the top.

    Because the total column has no color, the plot background shows through those cells. For that reason I also hid the gridlines, which would otherwise be drawn across them. You can change the background color or theme if you want the total column to have a different color.

    Fully reproducible example:

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go
    
    city_origin = ['London', 'Tokio', 'Seoul', 'Paris', 'Tashkent', 'Washington', 'Moscow']
    city_current = ['London', 'Madrid', 'Tashkent', 'Seoul', 'Paris', 'Toronto', 'Washington', 'Istanbul', 'Hanoi', 'Manilla', 'Delhi', 'Busan', 'Moscow']
    
    np.random.seed(42)
    migrant_origin = np.random.choice(city_origin, size=1000)
    migrant_current = np.random.choice(city_current, size=1000)
    migrant_salary = np.random.randint(1300, 6900, size=1000)
    df = pd.DataFrame({'migrant_origin': migrant_origin, 'migrant_current': migrant_current, 'migrant_salary': migrant_salary})
    
    new_df = pd.pivot_table(df, index='migrant_origin', columns='migrant_current', values='migrant_salary', aggfunc='sum', fill_value=0)
    
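    # Add a row total and move it to the front
    new_df['total'] = new_df.sum(axis=1)
    cols = ['total'] + [col for col in new_df.columns if col != 'total']
    new_df = new_df[cols]
    
    # go.Heatmap draws the first row of z at the bottom, so reverse the rows.
    # Values: blank out the total column so it gets no color
    new_df_values = new_df.iloc[::-1].copy()
    new_df_values['total'] = None
    # Text: the original numbers (including the totals) as strings
    new_df_text = new_df.iloc[::-1].astype(str)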
    
    ## separate out text and values into numpy arrays
    values_array = new_df_values.to_numpy()
    text_array = new_df_text.to_numpy()
    y_array = new_df_text.index.values
    x_array = new_df_text.columns.values
    
    fig = go.Figure()
    
    fig.add_trace(go.Heatmap(
        z=values_array,
        y=y_array,
        x=x_array,
        text=text_array,
        texttemplate="%{text:.2s}",  # 2 significant digits with an SI suffix, e.g. 45000 -> "45k"
    ))
    
    fig.update_layout(
        xaxis=dict(showgrid=False),
        yaxis=dict(showgrid=False),
    )
    
    fig.show()
    

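If more than one column should be left out of the coloring, the same idea can be wrapped in a small helper. The sketch below is not from the original answer. It assumes a hypothetical helper named split_values_and_text and uses a tiny hand-made table with illustrative numbers in place of the pivot. It also uses np.nan rather than None, which keeps the values numeric (float) instead of turning the blanked column into object dtype.

```python
import numpy as np
import pandas as pd


def split_values_and_text(df, exclude):
    """Hypothetical helper: return (values, text) copies of df.

    values: float copy where the columns in `exclude` are NaN, so a heatmap
            leaves those cells uncolored and ignores them when scaling colors.
    text:   string copy of the original numbers, used as cell labels.
    """
    values = df.astype(float)                   # new float copy of the data
    values.loc[:, list(exclude)] = np.nan       # blank out the excluded columns
    text = df.astype(str)                       # original numbers as labels
    return values, text


# Small illustrative table (made-up numbers) with a leading "total" column
table = pd.DataFrame(
    {"London": [10, 20], "Paris": [30, 40]},
    index=["Seoul", "Tokio"],
)
table.insert(0, "total", table.sum(axis=1))

values, text = split_values_and_text(table, exclude=["total"])
print(values)
print(text)
```

The two frames can then be passed as z and text to a heatmap trace. Any number of columns (for example a total column and an average column) can be listed in exclude.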