Tags: python, networkx

Create a graph using an edge attribute as nodes


I have a directed graph where the edges have the attribute edge_id. I want to create a new graph using the edge_id as nodes.

I think there should be a more straightforward way to do this than the code below. Is there one?

import networkx as nx
import matplotlib.pyplot as plt

# Each edge carries an "edge_id" attribute; those ids become the nodes of H.
edges = [("A","D", {"edge_id":1}),
         ("B","D", {"edge_id":2}),
         ("D", "G", {"edge_id":3}),
         ("C", "F", {"edge_id":4}),
         ("E", "F", {"edge_id":5}),
         ("F", "G", {"edge_id":6}),
         ("G", "I", {"edge_id":7}),
         ("H", "I", {"edge_id":8}),
         ("I", "J", {"edge_id":9}),
         ]

G = nx.DiGraph()
G.add_edges_from(edges)

# Left panel: the original graph.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10,5))
pos = nx.spring_layout(G)
nx.draw(G, with_labels=True, pos=pos, ax=ax[0])

# Sink: the first node with no outgoing edges and exactly one incoming edge.
# NOTE(review): assumes exactly one such node exists -- [0] raises IndexError
# when there is none, and any additional sinks are silently ignored.
end_node = [x for x in G.nodes() if G.out_degree(x)==0 and G.in_degree(x)==1][0]
# Sources: every node with no incoming edges.
start_nodes = [n for n, d in G.in_degree() if d == 0]

H = nx.DiGraph()
paths = []
# For each start node, collect the edge_ids along its path to the end node.
for start_node in start_nodes:
    my_list = []
    # NOTE(review): shortest_path returns a single path (and raises if the end
    # node is unreachable), so edges that lie only on alternative routes to
    # end_node never make it into H.
    path = nx.shortest_path(G, source=start_node, target=end_node)
    # Walk consecutive node pairs and record the edge_id of each hop.
    for n1, n2 in zip(path, path[1:]):
        my_list.append(G.edges[(n1, n2)]["edge_id"])
    paths.append(my_list)
# For the example data above, paths ==
# [[1, 3, 7, 9], [2, 3, 7, 9], [4, 6, 7, 9], [5, 6, 7, 9], [8, 9]]

# Consecutive edge_ids along each path become edges of H. Shared suffixes
# (e.g. 7 -> 9) are added several times, but a DiGraph keeps a single edge.
for sublist in paths:
    for n1, n2 in zip(sublist, sublist[1:]):
        H.add_edge(n1, n2)
# Right panel: the graph whose nodes are edge_ids.
nx.draw(H, with_labels=True, pos=nx.spring_layout(H), ax=ax[1])

[Image: the original graph (left) and the resulting edge_id graph (right)]


Solution

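  • The core of this is just remapping your existing edges to new ones based on the extra edge data. What you are describing is the *line graph* of `G`: every edge becomes a node, and edge `i` points to edge `j` when `i` ends at the node where `j` starts. networkx can build it directly with `nx.line_graph(G)`; its nodes are `(u, v)` tuples, which you can then relabel to their `edge_id` with `nx.relabel_nodes`. To do the remapping by hand on your edge list, a small function like this: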

    def remap_edges(edges):
        """Return the edges of the line graph of ``edges``, labelled by ``edge_id``.

        Each input edge ``(u, v, attrs)`` becomes a node named
        ``attrs["edge_id"]``. Edge ``i`` is connected to edge ``j`` whenever
        ``i`` ends at the node where ``j`` starts.

        Args:
            edges: Iterable of ``(source, target, attrs)`` 3-tuples, where
                ``attrs`` is a mapping with an ``"edge_id"`` key. It is
                traversed twice, so pass a list rather than a one-shot iterator.

        Returns:
            A list of ``(edge_id, next_edge_id)`` pairs, in input-edge order.
            An edge with neither a predecessor nor a successor yields no pair.

        Raises:
            KeyError: if an ``attrs`` mapping has no ``"edge_id"`` key.
        """
        # Collect *all* edge ids leaving each node. A plain {node: edge_id}
        # dict keeps only the last one and silently drops branching edges.
        outgoing = {}
        for source, _, attrs in edges:
            outgoing.setdefault(source, []).append(attrs["edge_id"])
        return [
            (attrs["edge_id"], next_id)
            for _, target, attrs in edges
            for next_id in outgoing.get(target, ())
        ]
    

    That function does the remapping for you. With it, the complete script becomes:

    import networkx as nx
    import matplotlib.pyplot as plt
    
    def remap_edges(edges):
        """Return the edges of the line graph of ``edges``, labelled by ``edge_id``.

        Each input edge ``(u, v, attrs)`` becomes a node named
        ``attrs["edge_id"]``. Edge ``i`` is connected to edge ``j`` whenever
        ``i`` ends at the node where ``j`` starts.

        Args:
            edges: Iterable of ``(source, target, attrs)`` 3-tuples, where
                ``attrs`` is a mapping with an ``"edge_id"`` key. It is
                traversed twice, so pass a list rather than a one-shot iterator.

        Returns:
            A list of ``(edge_id, next_edge_id)`` pairs, in input-edge order.
            An edge with neither a predecessor nor a successor yields no pair.

        Raises:
            KeyError: if an ``attrs`` mapping has no ``"edge_id"`` key.
        """
        # Collect *all* edge ids leaving each node. A plain {node: edge_id}
        # dict keeps only the last one and silently drops branching edges.
        outgoing = {}
        for source, _, attrs in edges:
            outgoing.setdefault(source, []).append(attrs["edge_id"])
        return [
            (attrs["edge_id"], next_id)
            for _, target, attrs in edges
            for next_id in outgoing.get(target, ())
        ]
    
    def build_chart(edges, ax):
        """Draw the directed graph described by ``edges`` onto the Axes ``ax``.

        ``edges`` is anything ``DiGraph.add_edges_from`` accepts; this script
        passes both ``(u, v, attrs)`` 3-tuples and ``(u, v)`` pairs.
        """
        graph = nx.DiGraph()
        graph.add_edges_from(edges)
        layout = nx.spring_layout(graph)
        nx.draw(graph, pos=layout, ax=ax, with_labels=True)
    
    # Same example data as in the question: the k-th pair gets edge_id k.
    _edge_pairs = [
        ("A", "D"), ("B", "D"), ("D", "G"),
        ("C", "F"), ("E", "F"), ("F", "G"),
        ("G", "I"), ("H", "I"), ("I", "J"),
    ]
    edges = [
        (source, target, {"edge_id": number})
        for number, (source, target) in enumerate(_edge_pairs, start=1)
    ]

    # Left panel: the original graph. Right panel: the edge_id graph.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    for panel, panel_edges in zip(ax, (edges, remap_edges(edges))):
        build_chart(panel_edges, panel)
    plt.show()
    

    That should reproduce your chart image (node positions will differ between runs, because `spring_layout` starts from random positions).