
How do I make a simple, multi-level Sankey diagram with Plotly?


I have a DataFrame like this that I'm trying to describe with a Sankey diagram:

import pandas as pd

pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})
    animal  sex     status          count
0   dog     male    wild            8
1   cat     female  domesticated    10
2   cat     female  domesticated    11
3   dog     male    wild            14
4   cat     male    domesticated    6

I'm trying to follow the example in the documentation, but I can't make it work: I don't understand which nodes branch into which. Here's the example code:

import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3], 
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()
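One way to see what branches where in the documentation example is to pair each link's source and target indexes with the node labels. A minimal sketch that only decodes the lists copied from the example above (it does not need Plotly):

```python
# Node labels and link lists copied from the Plotly documentation example
labels = ["A1", "A2", "B1", "B2", "C1", "C2"]
source = [0, 1, 0, 2, 3, 3]
target = [2, 3, 3, 4, 4, 5]
value = [8, 4, 2, 8, 4, 2]

# Link k goes from labels[source[k]] to labels[target[k]] with thickness value[k]
links = [(labels[s], labels[t], v) for s, t, v in zip(source, target, value)]
for src, dst, val in links:
    print(f"{src} -> {dst}: {val}")
```

Read this way, A1 feeds B1 (8) and B2 (2), A2 feeds B2 (4), B1 feeds C1 (8), and B2 splits into C1 (4) and C2 (2).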

Here's what I'm trying to achieve: [image of the desired Sankey diagram]


Solution

  • You can create a Sankey diagram with Plotly like this:

    import pandas as pd
    import plotly.graph_objects as go
    
    label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
    # cat: 0, dog: 1, domesticated: 2, female: 3, male: 4, wild: 5
    source = [0, 0, 1, 3, 4, 4]
    target = [3, 4, 4, 2, 2, 5]
    count = [21, 6, 22, 21, 6, 22]
    
    fig = go.Figure(data=[go.Sankey(
        node = {"label": label_list},
        link = {"source": source, "target": target, "value": count}
        )])
    fig.show()
    

    [Image: the resulting Sankey diagram]

    How it works: The lists source, target and count all have length 6, so the Sankey diagram has 6 arrows. The elements of source and target are indexes into label_list. So the first element of source is 0, which means "cat", and the first element of target is 3, which means "female". The first element of count is 21 (the 10 + 11 female cats in the DataFrame). Therefore, the first arrow of the diagram goes from cat to female and has size 21. Likewise, the second elements of source, target and count define the second arrow, and so on.
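    Instead of looking up the indexes by hand, the arrows can be written as (source label, target label, value) triples and converted with a label-to-index dictionary. A minimal sketch; the links list is simply the six arrows described above written out by name:

```python
# The six arrows from the example above, written by label instead of index
links = [
    ('cat', 'female', 21),
    ('cat', 'male', 6),
    ('dog', 'male', 22),
    ('female', 'domesticated', 21),
    ('male', 'domesticated', 6),
    ('male', 'wild', 22),
]

label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
# Map each label to its position in label_list
index = {label: i for i, label in enumerate(label_list)}

source = [index[s] for s, _, _ in links]
target = [index[t] for _, t, _ in links]
count = [v for _, _, v in links]
print(source, target, count)
# [0, 0, 1, 3, 4, 4] [3, 4, 4, 2, 2, 5] [21, 6, 22, 21, 6, 22]
```

    The resulting lists are identical to the hand-written source, target and count above, so they can be passed to go.Sankey unchanged.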


    If you want to create a bigger Sankey diagram, as in this example, defining the source, target and count lists manually quickly becomes tedious. Here is code that builds these lists from a DataFrame in your format:

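    import pandas as pd
    import numpy as np
    import plotly.graph_objects as go
    
    df = pd.DataFrame({
        'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
        'sex': ['male', 'female', 'female', 'male', 'male'],
        'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
        'count': [8, 10, 11, 14, 6]
    })
    
    # Columns in the order the levels should appear from left to right
    categories = ['animal', 'sex', 'status']
    
    # For each pair of neighbouring columns (animal->sex, sex->status), turn the
    # value pairs into links with a common source/target/count layout
    frames = []
    for i in range(len(categories)-1):
        tempDf = df[[categories[i], categories[i+1], 'count']]
        tempDf = tempDf.rename(columns={categories[i]: 'source', categories[i+1]: 'target'})
        frames.append(tempDf)
    newDf = pd.concat(frames)
    # Add up the counts of identical source->target links
    newDf = newDf.groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()
    
    # Every distinct value becomes a node; np.unique returns them sorted,
    # and a node's position in label_list is its index
    label_list = list(np.unique(df[categories].values))
    source = newDf['source'].apply(lambda x: label_list.index(x))
    target = newDf['target'].apply(lambda x: label_list.index(x))
    count = newDf['count']
    
    fig = go.Figure(data=[go.Sankey(
        node = {"label": label_list},
        link = {"source": source, "target": target, "value": count}
        )])
    fig.show()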
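    One caveat of building label_list with np.unique: equal values from different columns are merged into a single node, so if the same string appeared, say, both as a sex and as a status, the diagram would loop back on itself. A sketch of a hypothetical helper, sankey_lists (not part of pandas or Plotly), that instead keys each node by (column, value) and sums the links with a plain dict:

```python
from collections import defaultdict

import pandas as pd


def sankey_lists(df, categories, value_col='count'):
    """Hypothetical helper: return node labels and source/target/value lists.

    Nodes are keyed by (column, value), so the same value in two different
    columns becomes two separate nodes.
    """
    node_index = {}             # (column, value) -> node index
    labels = []                 # text shown on each node
    totals = defaultdict(int)   # (source index, target index) -> summed value

    def node(col, val):
        key = (col, val)
        if key not in node_index:
            node_index[key] = len(labels)
            labels.append(str(val))
        return node_index[key]

    # Links only run between neighbouring columns (animal->sex, sex->status)
    for left, right in zip(categories, categories[1:]):
        for s, t, v in zip(df[left], df[right], df[value_col]):
            totals[(node(left, s), node(right, t))] += v

    source = [s for s, _ in totals]
    target = [t for _, t in totals]
    value = list(totals.values())
    return labels, source, target, value


df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})
labels, source, target, value = sankey_lists(df, ['animal', 'sex', 'status'])
```

    The four returned lists can be passed to go.Sankey the same way as above (node={"label": labels}, link={"source": source, "target": target, "value": value}). Nodes are numbered in the order they are first seen rather than alphabetically, which changes the indexes but not which nodes are linked.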