python networkx

Create a graph using the edge attribute as node


I have a directed graph where the edges have the attribute edge_id. I want to create a new graph using the edge_id as nodes.

The code below works, but is there a more straightforward method than this?

import networkx as nx
import matplotlib.pyplot as plt

# Directed input graph: every edge is tagged with an "edge_id" attribute.
edges = [
    ("A", "D", {"edge_id": 1}),
    ("B", "D", {"edge_id": 2}),
    ("D", "G", {"edge_id": 3}),
    ("C", "F", {"edge_id": 4}),
    ("E", "F", {"edge_id": 5}),
    ("F", "G", {"edge_id": 6}),
    ("G", "I", {"edge_id": 7}),
    ("H", "I", {"edge_id": 8}),
    ("I", "J", {"edge_id": 9}),
]

G = nx.DiGraph()
G.add_edges_from(edges)

# Left axes show G; right axes show the derived graph H.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
pos = nx.spring_layout(G)
nx.draw(G, with_labels=True, pos=pos, ax=ax[0])

# The end node is the first node with no outgoing edge and exactly one
# incoming edge; start nodes are the nodes with no incoming edge.
sink_candidates = [
    node for node in G.nodes()
    if G.out_degree(node) == 0 and G.in_degree(node) == 1
]
end_node = sink_candidates[0]
start_nodes = [node for node, indeg in G.in_degree() if indeg == 0]

# Walk from each start node to the end node and record the edge_ids
# encountered along that path.
paths = []
for source in start_nodes:
    node_path = nx.shortest_path(G, source=source, target=end_node)
    hops = zip(node_path, node_path[1:])
    paths.append([G.edges[u, v]["edge_id"] for u, v in hops])
# paths
# [[1, 3, 7, 9], [2, 3, 7, 9], [4, 6, 7, 9], [5, 6, 7, 9], [8, 9]]

# Consecutive edge_ids on a path become an edge of H.
H = nx.DiGraph()
H.add_edges_from(
    pair for id_path in paths for pair in zip(id_path, id_path[1:])
)
pos_H = nx.spring_layout(H)
nx.draw(H, with_labels=True, pos=pos_H, ax=ax[1])

enter image description here


Solution

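  • This approach processes every edge of G directly instead of computing paths. For each edge (u, v), it links that edge's edge_id to the edge_id of every edge that leaves v, so two edge_ids are connected in H exactly when their edges are consecutive in G. Because each edge is handled once, with no repeated path searches, this is much faster on larger or more complex graphs, and it does not depend on choosing start and end nodes.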

    import networkx as nx
    import matplotlib.pyplot as plt
    
    # Directed input graph: every edge is tagged with an "edge_id" attribute.
    edges = [("A", "D", {"edge_id": 1}),
             ("B", "D", {"edge_id": 2}),
             ("D", "G", {"edge_id": 3}),
             ("C", "F", {"edge_id": 4}),
             ("E", "F", {"edge_id": 5}),
             ("F", "G", {"edge_id": 6}),
             ("G", "I", {"edge_id": 7}),
             ("H", "I", {"edge_id": 8}),
             ("I", "J", {"edge_id": 9})]
    
    G = nx.DiGraph()
    G.add_edges_from(edges)
    
    # Map each edge (u, v) of G to the edge_id that names its node in H.
    # An edge without an "edge_id" attribute raises KeyError here.
    edge_id_to_node = {(u, v): data["edge_id"] for u, v, data in G.edges(data=True)}
    
    # The line graph of a directed graph has one node per edge (u, v) of G and
    # an edge (u, v) -> (v, w) for every pair of consecutive edges, which is
    # exactly the structure wanted.  Relabelling its (u, v) nodes with their
    # edge_id yields H without any hand-written loops or membership checks.
    H = nx.relabel_nodes(nx.line_graph(G), edge_id_to_node)
    
    # Draw both graphs side by side: G on the left, H on the right.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
    
    # One (graph, node colour, title) entry per axes, in left-to-right order.
    panels = (
        (G, 'lightblue', "Original Graph (G)"),
        (H, 'lightgreen', "Transformed Graph (H - edge_id as nodes)"),
    )
    for axis, (graph, fill, title) in zip(ax, panels):
        layout = nx.spring_layout(graph)
        nx.draw(graph, pos=layout, with_labels=True, node_color=fill, edge_color='gray', ax=axis)
        axis.set_title(title)
    
    plt.tight_layout()
    plt.show()