How do I get a discrete colour bar for a plotly Scatter plot?


I have defined a list of 5 colours, which I use to colour the markers in the go.Scatter plot below.

I want the colour bar (legend) to show these discrete colours, but the resulting colour bar, titled "outputs", stays white/empty.

Reproducible example

import numpy as np
import plotly.graph_objects as go

#Data for testing
outputs = np.random.choice([0, 1, 2, 3, 4], size=100)

#Figure
marker_palette=[
     'rgb(39, 100, 25)',   # ++
     'rgb(156, 207, 100)', # +
     'rgb(247, 247, 247)', # 0
     'rgb(232, 150, 196)', # -
     'rgb(142, 1, 82)',    # --
] 

scatter = go.Scatter(
    x=np.random.randn(100),
    y=np.random.randn(100),
    mode='markers',
    marker_color=[marker_palette[o] for o in outputs],
    marker=dict(
        # color=outputs, colorscale=marker_palette, #creates a continuous scale - not what I want
        colorbar={
            'title': 'outputs',
            'tickmode': 'array',
            'tickvals': np.unique(outputs),
            'ticktext': ['++', '+', '0', '-', '--']
        },
        #/colorbar
    ),
    #/marker
)
go.Figure(scatter)

I can get a continuous colour bar without any problems, but I want a discrete one.

I would prefer solutions using graph_objects as opposed to plotly.express, as I want to see how the internals get configured in your answer (though it's not a strict requirement).

How do I get a colour bar that shows the discrete colours?


Solution

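  • Pass the numeric outputs as marker=dict(color=...) instead of a list of colour strings (the colour bar only has an effect when marker.color is a numerical array), and supply a colorscale in which each colour appears twice, at the start and at the end of its band. Adjacent colours then meet at the same position, so the scale consists of bands of constant colour instead of a gradient: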

    import numpy as np
    import plotly.graph_objects as go
    
    #Data for testing
    outputs = np.random.choice([0, 1, 2, 3, 4], size=100)
    
    #Figure
    marker_colorscale=[
         (0, 'rgb(39, 100, 25)'), (1/5, 'rgb(39, 100, 25)'),
         (1/5, 'rgb(156, 207, 100)'), (2/5, 'rgb(156, 207, 100)'),
         (2/5, 'rgb(247, 247, 247)'), (3/5, 'rgb(247, 247, 247)'),
         (3/5, 'rgb(232, 150, 196)'), (4/5, 'rgb(232, 150, 196)'),
         (4/5, 'rgb(142, 1, 82)'), (5/5, 'rgb(142, 1, 82)'),
    ] 
    
    scatter = go.Scatter(
        x=np.random.randn(100),
        y=np.random.randn(100),
        mode='markers',
        marker=dict(
            color=outputs,
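            colorscale=marker_colorscale,
            # Fix the colour range so each integer output sits at the centre of its band
            cmin=-0.5,
            cmax=4.5,
            colorbar={
                'title': 'outputs',
                'tickmode': 'array',
                'tickvals': [0, 1, 2, 3, 4],  # one tick per label, even if a value is not sampled
                'ticktext': ['++', '+', '0', '-', '--']
            },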
            #/colorbar
        ),
        #/marker
    )
    go.Figure(scatter)