Tags: python, plotly, plotly-dash, sankey-diagram

Remove an unused node from a Plotly Sankey diagram in Python


What I'm trying to do here is create relationships between tasks. Some of them are connected directly to each other, while others pass through the big box I circled instead of connecting directly (direct connections are what I need). How can I remove this node?

def generate_links_and_nodes(dataframe):
        """Build Plotly Sankey inputs from the q10 -> q3 -> q11 task columns.

        Each cell holds a comma-separated list of task ids (e.g. "T4.1, T4.2").
        For every row, each q10 task is linked to each q3 task and each q3 task
        to each q11 task; links from a task to itself are skipped.

        Blank cells (and non-string cells such as NaN/None) contribute no tasks.
        Without this, ``"".split(", ")`` returns ``[""]``, so every blank cell
        would add an empty-label ``""`` node that rows with a blank q10 or q11
        are routed through.

        Args:
            dataframe: pandas DataFrame with 'q10', 'q3' and 'q11' columns.

        Returns:
            Tuple ``(sources, targets, values, unique_nodes)`` for ``go.Sankey``:
            link endpoints as indices into the sorted ``unique_nodes`` labels,
            and a weight of 1 per link.
        """
        def _split_tasks(cell):
            # Parse one cell into a set of task ids; blank / non-string -> empty set.
            if not isinstance(cell, str):
                return set()
            return {task.strip() for task in cell.split(',') if task.strip()}

        cleaned_links = []
        for _, row in dataframe.iterrows():
            q10_tasks = _split_tasks(row['q10'])
            q3_tasks = _split_tasks(row['q3'])
            q11_tasks = _split_tasks(row['q11'])

            # Create links between q10 and q3 (skipping self-links)
            cleaned_links.extend(
                (q10, q3) for q10 in q10_tasks for q3 in q3_tasks if q10 != q3
            )
            # Create links between q3 and q11 (skipping self-links)
            cleaned_links.extend(
                (q3, q11) for q3 in q3_tasks for q11 in q11_tasks if q3 != q11
            )

        # DataFrame from links
        links_df = pd.DataFrame(cleaned_links, columns=["source", "target"])

        # Collect unique nodes (only tasks that take part in at least one link)
        unique_nodes = sorted(set(pd.concat([links_df['source'], links_df['target']])))
        node_indices = {node: i for i, node in enumerate(unique_nodes)}

        # Map sources and targets to node indices
        sources = links_df['source'].map(node_indices).tolist()
        targets = links_df['target'].map(node_indices).tolist()
        values = [1] * len(links_df)  # Default weight of 1 for each link

        return sources, targets, values, unique_nodes

    # Turn the task table into Sankey node labels and link arrays.
    sources, targets, values, nodes = generate_links_and_nodes(df)

    # Node styling; every task that takes part in at least one link gets a label.
    node_spec = dict(
        pad=25,
        thickness=70,
        line=dict(color="black", width=0.5),
        label=nodes,
    )
    # Links are given as index pairs into `nodes`, each with its weight.
    link_spec = dict(source=sources, target=targets, value=values)

    # Create the Sankey diagram from a single trace.
    fig = go.Figure(data=[go.Sankey(node=node_spec, link=link_spec)])

enter image description here

Sample data of the query results: these are the results from my database after running `df_q10 = pd.read_sql_query(query_q10, conn)`, `df_q3 = pd.read_sql_query(query_q3, conn)` and `df_q11 = pd.read_sql_query(query_q11, conn)`.

                  q3
0               T4.2
1   T4.2, T4.3, T4.4
2               T2.3
3               T2.2
4               T6.3
5               T6.3
6               T6.3
7         T4.1, T4.2
8               T1.3
9               T1.2
10              T1.3
11              T1.3
12              T7.3
13              T2.3
14              T2.1
                             q10
0
1
2
3
4                           T6.2
5                           T6.2
6
7   T1.1, T3.1, T3.2, T4.4, T5.1
8
9
10
11
12                          T7.1
13        T2.1, T2.2, T2.4, T3.2
14
                       q11
0
1   T1.1, T1.3, T3.1, T3.2
2
3
4
5
6
7   T1.1, T1.3, T3.1, T3.2
8
9
10
11
12                    T7.2
13
14

Solution

  • To eliminate the intermediary nodes, you need to identify the nodes that act as unnecessary pass-throughs and bypass them by creating direct links between the relevant tasks. In my example, the intermediary nodes are the q3 tasks that connect q10 (the starting tasks) to q11 (the ending tasks) without adding any context or relationships of their own. I flag these as intermediary, and the links passing through them are replaced by direct connections between the corresponding q10 and q11 nodes. Note also that a blank cell split on `', '` yields `['']`, which creates an unlabeled `''` node that rows with a blank q10 or q11 are routed through; blank cells should therefore be skipped (the second version below does this). Below is the necessary addition to your code, together with plots with the intermediary links (your approach) and without them:

    import pandas as pd
    import plotly.graph_objects as go
    
    # Toy input: one list per column, each cell a ", "-separated task list.
    data = dict(
        q10=["A, B", "C, D", "E, F"],
        q3=["X", "Y, X", "Z"],
        q11=["G, H", "I, J", "K"],
    )
    df = pd.DataFrame(data)
    
    def generate_links_and_nodes(dataframe, remove_intermediates=True):
        """Build Plotly Sankey inputs from the q10 -> q3 -> q11 task columns.

        Each cell holds a comma-separated list of task ids. For every row, each
        q10 task is linked to each q3 task and each q3 task to each q11 task.
        Blank or non-string (NaN/None) cells contribute no tasks, so no
        empty-label node is created. Links from a task to itself are skipped.

        Args:
            dataframe: pandas DataFrame with 'q10', 'q3' and 'q11' columns.
            remove_intermediates: if True, a row whose q3 tasks sit between
                q10 and q11 tasks (all three cells non-blank) is drawn with
                direct q10 -> q11 links instead, so its q3 tasks are bypassed.
                Rows with no q10 or no q11 tasks keep their q3 links.

        Returns:
            Tuple ``(sources, targets, values, unique_nodes)`` for ``go.Sankey``:
            link endpoints as indices into the sorted ``unique_nodes`` labels,
            and a weight of 1 per link.
        """
        def _split_tasks(cell):
            # Parse one cell into a set of task ids; blank / non-string -> empty set.
            if not isinstance(cell, str):
                return set()
            return {task.strip() for task in cell.split(',') if task.strip()}

        cleaned_links = []
        for _, row in dataframe.iterrows():
            q10_tasks = _split_tasks(row['q10'])
            q3_tasks = _split_tasks(row['q3'])
            q11_tasks = _split_tasks(row['q11'])

            if remove_intermediates and q10_tasks and q3_tasks and q11_tasks:
                # The bypass uses THIS row's q10/q11 tasks; the link's row of
                # origin must be known here, it cannot be recovered afterwards.
                row_links = [(q10, q11) for q10 in q10_tasks for q11 in q11_tasks]
            else:
                row_links = [(q10, q3) for q10 in q10_tasks for q3 in q3_tasks]
                row_links += [(q3, q11) for q3 in q3_tasks for q11 in q11_tasks]

            # Skip task -> same-task links; they add no relationship between tasks.
            cleaned_links.extend(link for link in row_links if link[0] != link[1])

        links_df = pd.DataFrame(cleaned_links, columns=["source", "target"])

        unique_nodes = sorted(set(pd.concat([links_df['source'], links_df['target']])))
        node_indices = {node: i for i, node in enumerate(unique_nodes)}

        sources = links_df['source'].map(node_indices).tolist()
        targets = links_df['target'].map(node_indices).tolist()
        values = [1] * len(links_df)  # each link counts once

        return sources, targets, values, unique_nodes
    
    # Compute both variants first so the two figures can be compared.
    sources_with, targets_with, values_with, nodes_with = generate_links_and_nodes(df, remove_intermediates=False)
    sources_without, targets_without, values_without, nodes_without = generate_links_and_nodes(df, remove_intermediates=True)


    def _make_sankey(labels, link_sources, link_targets, link_values, title):
        """Return a titled Sankey figure using the shared node styling."""
        figure = go.Figure(data=[go.Sankey(
            node=dict(pad=25, thickness=20, line=dict(color="black", width=0.5), label=labels),
            link=dict(source=link_sources, target=link_targets, value=link_values),
        )])
        figure.update_layout(title_text=title, font_size=10)
        return figure


    fig_with = _make_sankey(nodes_with, sources_with, targets_with, values_with, "With Intermediate Nodes")
    fig_with.show()

    fig_without = _make_sankey(nodes_without, sources_without, targets_without, values_without, "Without Intermediate Nodes")
    fig_without.show()
    
    

    Which gives

    enter image description here

    and

    enter image description here

    Edit: with your posted data

    I think this is applicable to your data:

    import pandas as pd
    import plotly.graph_objects as go
    
    # One list per column (15 rows); "" marks a blank cell in the query results.
    data = dict(
        q10=[
            "", "", "", "", "T6.2", "T6.2", "",
            "T1.1, T3.1, T3.2, T4.4, T5.1",
            "", "", "", "", "T7.1",
            "T2.1, T2.2, T2.4, T3.2", "",
        ],
        q3=[
            "T4.2", "T4.2, T4.3, T4.4", "T2.3", "T2.2", "T6.3", "T6.3", "T6.3",
            "T4.1, T4.2", "T1.3", "T1.2", "T1.3", "T1.3", "T7.3", "T2.3", "T2.1",
        ],
        q11=[
            "", "T1.1, T1.3, T3.1, T3.2", "", "", "", "", "",
            "T1.1, T1.3, T3.1, T3.2", "", "", "", "", "T7.2", "", "",
        ],
    )
    df = pd.DataFrame(data)
    
    def generate_links_and_nodes(dataframe, remove_intermediates=True):
        """Build Plotly Sankey inputs from the q10 -> q3 -> q11 task columns.

        Each cell holds a comma-separated list of task ids; blank or non-string
        (NaN/None) cells contribute no tasks, so no empty-label node is created.
        For every row, each q10 task is linked to each q3 task and each q3 task
        to each q11 task. Links from a task to itself are skipped.

        Args:
            dataframe: pandas DataFrame with 'q10', 'q3' and 'q11' columns.
            remove_intermediates: if True, a row whose q3 tasks sit between
                q10 and q11 tasks (all three cells non-blank) is drawn with
                direct q10 -> q11 links instead, so its q3 tasks are bypassed.
                Rows with no q10 or no q11 tasks keep their q3 links.

        Returns:
            Tuple ``(sources, targets, values, unique_nodes)`` for ``go.Sankey``:
            link endpoints as indices into the sorted ``unique_nodes`` labels,
            and a weight of 1 per link.
        """
        def _split_tasks(cell):
            # Parse one cell into a set of task ids; blank / non-string -> empty set.
            if not isinstance(cell, str):
                return set()
            return {task.strip() for task in cell.split(',') if task.strip()}

        cleaned_links = []
        for _, row in dataframe.iterrows():
            q10_tasks = _split_tasks(row['q10'])
            q3_tasks = _split_tasks(row['q3'])
            q11_tasks = _split_tasks(row['q11'])

            if remove_intermediates and q10_tasks and q3_tasks and q11_tasks:
                # Bypass q3 while the row is in hand: no later lookup by q3 id is
                # needed, which avoids regex/substring matches such as "T1.1"
                # matching "T1.11", and each pair is emitted once per row.
                row_links = [(q10, q11) for q10 in q10_tasks for q11 in q11_tasks]
            else:
                row_links = [(q10, q3) for q10 in q10_tasks for q3 in q3_tasks]
                row_links += [(q3, q11) for q3 in q3_tasks for q11 in q11_tasks]

            # Skip task -> same-task links; they add no relationship between tasks.
            cleaned_links.extend(link for link in row_links if link[0] != link[1])

        links_df = pd.DataFrame(cleaned_links, columns=["source", "target"])

        unique_nodes = sorted(set(pd.concat([links_df['source'], links_df['target']])))
        node_indices = {node: i for i, node in enumerate(unique_nodes)}

        sources = links_df['source'].map(node_indices).tolist()
        targets = links_df['target'].map(node_indices).tolist()
        values = [1] * len(links_df)  # each link counts once

        return sources, targets, values, unique_nodes
    
    # Compute both variants first so the two figures can be compared.
    sources_with, targets_with, values_with, nodes_with = generate_links_and_nodes(df, remove_intermediates=False)
    sources_without, targets_without, values_without, nodes_without = generate_links_and_nodes(df, remove_intermediates=True)


    def _make_sankey(labels, link_sources, link_targets, link_values, title):
        """Return a titled Sankey figure using the shared node styling."""
        figure = go.Figure(data=[go.Sankey(
            node=dict(pad=25, thickness=20, line=dict(color="black", width=0.5), label=labels),
            link=dict(source=link_sources, target=link_targets, value=link_values),
        )])
        figure.update_layout(title_text=title, font_size=10)
        return figure


    fig_with = _make_sankey(nodes_with, sources_with, targets_with, values_with, "With Intermediate Nodes")
    fig_with.show()

    fig_without = _make_sankey(nodes_without, sources_without, targets_without, values_without, "Without Intermediate Nodes")
    fig_without.show()
    

    which gives:

    enter image description here enter image description here