pythoncolorscluster-analysisplotly-dashdropdownbox

How to dynamically change color of selected category using dropdown box?


I am working on an app that takes in 2 inputs to update a scatterplot displaying the results of a cluster analysis. The first input filters the points on the graph through a time range slider. The second input is a dropdown box that is intended to highlight the color of a category of interest on the graph. The categories available in the dropdown box are the different clusters resulting from cluster analysis done on the time range of data. Depending on the time range selected, there may be a different number of clusters available.

I want to be able to color the categories not selected in various continuous shades of grey and the selected category to be colored green or something vibrant that stands out. As of now, I have a combination of

color_discrete_map

and

color_continuous_scale

to meet this requirement, but it does not seem to be working. The color_discrete_map argument seems to be overruled by the color_continuous_scale argument.

Here is my code so far:

#Import packages
import pandas as pd
import numpy as np
import os
import plotly.express as px
import dash
from dash import dcc, html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output
from sklearn.cluster import KMeans
from sklearn.impute import KNNImputer
from kneed import KneeLocator
from sklearn.preprocessing import StandardScaler
import random

cluster_choices = ["Cluster 1","Cluster 2","Cluster 3","Cluster 4","Cluster 5"]

options1 = ['Alabama','Wyoming','California','Tennessee','Texas']
options2 = random.sample(range(1991,2020),20)
options3 = random.sample(range(1000),20)
options4 = random.sample(range(750),20)
options5 = random.sample(range(500),20)

list1 = np.random.choice(options1, size=250, replace=True).tolist()
list2 = np.random.choice(options2, size=250, replace=True).tolist()
list3 = np.random.choice(options3, size=250, replace=True).tolist()
list4 = np.random.choice(options4, size=250, replace=True).tolist()
list5 = np.random.choice(options5, size=250, replace=True).tolist()



df = pd.DataFrame(list(zip(list1, list2, list3, list4, list5)),
               columns =['State', 'Year','Metric1','Metric2','Metric3'])



app = dash.Dash(__name__,assets_folder=os.path.join(os.curdir,"assets"))
server = app.server
app.layout = html.Div([
    dcc.Tabs([
        dcc.Tab(label='Dashboard',
            children=[
                dbc.Row([
                    dbc.Col([
                        dcc.RangeSlider(
                                    id='range_slider',
                                    min=1991,
                                    max=2020,
                                    step=1,
                                    value=[1991, 2020],
                                    allowCross=False,
                                    pushable=2,
                                    tooltip={"placement": "bottom", "always_visible": True}
                        )
                    ],width=6),
                    dbc.Col([
                        dcc.Dropdown(
                            id='dropdown1',
                            options=[{'label': i, 'value': i} for i in cluster_choices],
                            value=cluster_choices[0],
                        )
                    ],width=6)
                ]),
                 dbc.Row([
                    dbc.Col([
                        dcc.Graph(id='cluster_graph')
                    ],width=12)
             
                ])
            ]
        )

    ])
])

#Configure reactivity of cluster map controlled by range slider
@app.callback(
    Output('cluster_graph', 'figure'), 
    Output('dropdown1', 'options'),
    Input('range_slider', 'value'),
    Input('dropdown1','value')
) 

def update_cluster_graph(slider_range_values,dd1):
    filtered = df[(df['Year']>=slider_range_values[0]) & (df['Year']<=slider_range_values[1])]

    X = filtered

    #Step 1.) Break out into state and non-state dfs
    states = pd.DataFrame(X[['State']])
    not_states = X.loc[:, ~X.columns.isin(['State'])]

    #Step 2.) Impute the non-text columns
    imputer = KNNImputer(n_neighbors=5)
    not_states_fixed = pd.DataFrame(imputer.fit_transform(not_states),columns=not_states.columns)

    #Step 3.) Perform clustering
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(not_states_fixed)

    #Step 4.) Define the kmeans function with initialization as k-means++
    kmeans = KMeans(n_clusters=3, init='k-means++')

    #Step 5.) Fit the k means algorithm on scaled data
    kmeans.fit(data_scaled)

    SSE = []
    for cluster in range(1,10):
        kmeans = KMeans(n_clusters = cluster, init='k-means++')
        kmeans.fit(data_scaled)
        SSE.append(kmeans.inertia_)
        
    kl = KneeLocator(
        range(1, 10), SSE, curve="convex", direction="decreasing"
    )
    #Step 6.) Identify # of clusters
    elbow = kl.elbow

    kmeans = KMeans(n_clusters = elbow, init='k-means++')
    kmeans.fit(data_scaled)
    pred = kmeans.predict(data_scaled)

    #Step 7.) Get clusters back in original df
    frame = pd.DataFrame(data_scaled)
    frame['cluster'] = pred
    frame['cluster'].value_counts()

    clusters = frame['cluster'] +1

    not_states_fixed = not_states_fixed.dropna()
    not_states_fixed['cluster'] = clusters.values
    not_states_fixed['cluster'] = not_states_fixed['cluster'].astype('str')

    state_list = states['State'].values.tolist()

    not_states_fixed['State'] = state_list

    X = not_states_fixed

    
    #This is the filtered list that gets populated in the dropdown box
    cluster_list = X['cluster'].unique().tolist()
    cluster_list.sort()
    label = 'Cluster '
    new_cluster_list = [label + x for x in cluster_list]

    sortedX = X.sort_values(by='cluster',ascending=True)
    sortedX['cluster_num'] = sortedX['cluster'].astype(int)
    sortedX['cluster'] = 'Cluster '+sortedX['cluster']

    fig = px.scatter(
        sortedX,
        x="Metric1", 
        y="Metric2", 
        color="cluster_num",
        color_discrete_map={
                f"{dd1}": "green"
        },
        color_continuous_scale="Greys",
        hover_data = {
            "State":True,
            "Year":True,
            "Metric1":True,
            "Metric2":True,
            "Metric3":True
        },
        template='plotly_dark'

    )
    fig.update_traces(marker=dict(size=10,
                              line=dict(width=0.5,
                                        color='white')),
                  selector=dict(mode='markers'))
           
    return fig, [{'label':i,'value':i} for i in new_cluster_list]
    
if __name__=='__main__':
    app.run_server()

How can I fix the issues mentioned above?


Solution

  • You have three different ways of changing the colors:

    color="cluster_num",
    color_discrete_map={
            f"{dd1}": "green"
    },
    color_continuous_scale="Greys",
    

    And they contradict each other. First of all remove "color_continuous_scale" and "color_discrete_map" entirely. You seem to want custom defined colors, not predefined colors.

    From my experience the best way to do this is to add a separate column that has a color defined for each row of the data. This way you have complete control of the color logic.

    Create a column that holds the color value for each datapoint. This can be a hexadecimal value or RGBA (if you also want to control opacity). A pandas lambda function is a simple way of doing this:

    # adjust color threshold as needed
    def assign_color(value):
        float_value = float(value)
        if float_value > 2:
            return "rgba(250, 0.0, 0.0, 1.0)"
        elif float_value > 0 and float_value < 2:
            return "rgba(250, 138, 5, 1.0)"
        else:
            return "rgba(236, 236, 236, 0.6)"  # "lightgray" 
    
    sortedX['color'] = result.apply(lambda x: assign_color(x["data_column"]), axis = 1)
    

    Then, you should just be able to assign the color to the value of the column:

    fig = px.scatter(
        sortedX,
        x="Metric1", 
        y="Metric2", 
        color="color",
        ...
    )
    

    Edit: If that does not work then its probably because of your marker definition. Remove the color="color" line from above and update your marker definition (in fig.update_traces):

    marker=dict(color = sortedX["color"],
                size = 4)