python, plotly, plotly-python

Plotly Express px.scatter: how to keep points on categorical axes from overlapping


I'm trying to compare the categorical assignments of the same items returned by two databases. I want to plot each database on a categorical axis, then use px.scatter to visualize the intersections between them.

The problem I'm running into is that px.scatter doesn't seem to have an option for jittering the data points so they are not all on top of each other. I found the layout option scattermode='group', together with scattergap (a value between 0 and 1), but it doesn't do anything in my case.

I am able to get what I want out of JMP. How can I replicate this in plotly express?

MWE

import pandas as pd
import plotly.express as px

d = {'Document_Type_x': ['Research Article', 'Research Article', 'Letter to the Editor', 'Letter to the Editor', 'Letter to the Editor'],
     'Document_Type_y': ['Article', 'Article', 'Letter', 'Letter', 'Letter']}
df = pd.DataFrame(data=d)

fig = px.scatter(df, x='Document_Type_x', y='Document_Type_y')

fig.update_layout(scattermode='group', scattergap=.9)
fig.update_xaxes(categoryorder = 'category ascending')
fig.update_yaxes(categoryorder = 'category ascending')

fig.show()

[Figure: px.scatter output, incorrectly stacking all data points]

[Figure: JMP output, correctly showing how many points are at each intersection]


Solution

  • I had never used this feature before, so I checked the reference again. In the reference example, grouping works because the points are split by a color category, so I intentionally add such a category here. I then undo the side effects of adding it by modifying the legend, the marker color, and the hover template. It is a hack, but I think it will get you what you are after.

    import pandas as pd
    import plotly.express as px
    
    d = {'Document_Type_x': ['Research Article', 'Research Article', 'Letter to the Editor', 'Letter to the Editor', 'Letter to the Editor'],
         'Document_Type_y': ['Article', 'Article', 'Letter', 'Letter', 'Letter']}
    df = pd.DataFrame(data=d)
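    # Workaround: give the rows at each x/y intersection distinct labels so px.scatter creates one trace per label
    df['type'] = ['A','B','A','B','C']
    
    fig = px.scatter(df, x='Document_Type_x', y='Document_Type_y', color='type')
    
    # Hide the side effects of the helper column: legend entries, per-trace colors, and the extra hover field
    fig.update_traces(showlegend=False, marker=dict(color='blue'))
    fig.update_traces(hovertemplate='Document_Type_x: %{x}<br>Document_Type_y: %{y}<extra></extra>')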
    
    fig.update_layout(scattermode='group', scattergap=0.9)
    fig.update_xaxes(categoryorder = 'category ascending')
    fig.update_yaxes(categoryorder = 'category ascending')
    
    fig.show()
    
