I have a DataFrame like this that I'm trying to describe with a Sankey diagram:
import pandas as pd
pd.DataFrame({
'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
'sex': ['male', 'female', 'female', 'male', 'male'],
'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
'count': [8, 10, 11, 14, 6]
})
   animal     sex        status  count
0     dog    male          wild      8
1     cat  female  domesticated     10
2     cat  female  domesticated     11
3     dog    male          wild     14
4     cat    male  domesticated      6
I'm trying to follow the steps in the documentation but I can't make it work - I can't understand what branches where. Here's the example code:
import plotly.graph_objects as go
fig = go.Figure(data=[go.Sankey(
    node = dict(
        pad = 15,
        thickness = 20,
        line = dict(color = "black", width = 0.5),
        label = ["A1", "A2", "B1", "B2", "C1", "C2"],
        color = "blue"
    ),
    link = dict(
        source = [0, 1, 0, 2, 3, 3],
        target = [2, 3, 3, 4, 4, 5],
        value = [8, 4, 2, 8, 4, 2]
    ))])
fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()
You can create a Sankey diagram with Plotly in the following way:
import pandas as pd
import plotly.graph_objects as go
label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
# cat: 0, dog: 1, domesticated: 2, female: 3, male: 4, wild: 5
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]
fig = go.Figure(data=[go.Sankey(
node = {"label": label_list},
link = {"source": source, "target": target, "value": count}
)])
fig.show()
How it works: the lists source, target and count all have length 6, and the Sankey diagram has 6 arrows. The elements of source and target are indexes into label_list. So the first element of source is 0, which means "cat". The first element of target is 3, which means "female". The first element of count is 21. Therefore, the first arrow of the diagram goes from cat to female and has size 21. Correspondingly, the second elements of the lists source, target and count define the second arrow, and so on.
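If it helps to see the mapping spelled out, here is a small sketch that translates the index lists back into label pairs. The name links is only used for illustration here and is not part of Plotly.

```python
label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

# Each position i describes one arrow:
# from label_list[source[i]] to label_list[target[i]] with size count[i].
links = [(label_list[s], label_list[t], v)
         for s, t, v in zip(source, target, count)]

for src, tgt, size in links:
    print(f"{src} -> {tgt}: {size}")
```

Reading the printed arrows next to the rendered diagram is a quick way to check that the indexes point at the labels you intended.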
You may want to create a bigger Sankey diagram, as in this example. Defining the source, target and count lists manually then becomes very tedious, so here is code that creates these lists from a DataFrame in your format.
import pandas as pd
import numpy as np
df = pd.DataFrame({
'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
'sex': ['male', 'female', 'female', 'male', 'male'],
'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
'count': [8, 10, 11, 14, 6]
})
categories = ['animal', 'sex', 'status']
newDf = pd.DataFrame()
for i in range(len(categories)-1):
    # each pair of adjacent columns contributes one set of source -> target links
    tempDf = df[[categories[i],categories[i+1],'count']]
    tempDf.columns = ['source','target','count']
    newDf = pd.concat([newDf,tempDf])
newDf = newDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
label_list = list(np.unique(df[categories].values))
source = newDf['source'].apply(lambda x: label_list.index(x))
target = newDf['target'].apply(lambda x: label_list.index(x))
count = newDf['count']