networkx, visualization, graph-visualization, pytorch-geometric

How to visualize a PyTorch Geometric HeteroData graph with any tool?


Hello, what is a good way to visualize a PyG HeteroData object (defined as in https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html#creating-heterogeneous-gnns )?

I tried networkx, but I think it is restricted to homogeneous graphs. It is possible to convert the HeteroData object to a homogeneous graph first, but the result is much less informative:

    import torch_geometric.utils

    g = torch_geometric.utils.to_networkx(data.to_homogeneous(), to_undirected=False)
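For reference, the homogeneous route need not drop the types entirely: by default `to_homogeneous()` adds integer `node_type` and `edge_type` tensors, and `to_networkx` can copy them onto the networkx graph through its `node_attrs` and `edge_attrs` arguments. Below is a minimal sketch of mapping those integer ids back to readable names, assuming id `i` corresponds to `data.node_types[i]` / `data.edge_types[i]`. `name_types` is a hypothetical helper, and a small hand-built graph stands in for the converted one so the sketch runs without PyG.

```python
import networkx as nx


def name_types(graph, node_type_names, edge_type_names):
    """Hypothetical helper: turn integer node_type/edge_type ids into a
    readable "type" attribute on every node and edge."""
    for _, attrs in graph.nodes(data=True):
        attrs["type"] = node_type_names[attrs["node_type"]]
    for _, _, attrs in graph.edges(data=True):
        attrs["type"] = edge_type_names[attrs["edge_type"]]
    return graph


# In practice (assumed usage, not run here):
#   g = to_networkx(data.to_homogeneous(),
#                   node_attrs=["node_type"], edge_attrs=["edge_type"])
#   name_types(g, data.node_types, data.edge_types)
# Hand-built stand-in for the converted graph:
g = nx.DiGraph()
g.add_node(0, node_type=0)
g.add_node(1, node_type=1)
g.add_edge(0, 1, edge_type=0)
name_types(g, ["Station", "Lot"], [("Station", "ProcessNow", "Lot")])
```

After this step the nodes and edges carry the same kind of `"type"` attribute that type-based coloring code can key on.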

Has anyone tried this with another Python library (e.g. matplotlib) or a JavaScript library (sigma.js/d3.js)?

Are there any documentation links you can share?
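For the JavaScript side, one option is to export the typed networkx graph to JSON and load it in the browser. This is a minimal sketch that assumes the d3-force convention of a `nodes` array plus a `links` array whose `source`/`target` refer to node ids (matched in d3 with `forceLink().id(d => d.id)`). `to_d3_json` is a hypothetical helper, and the field layout is an illustrative choice, not a format PyG defines. Sigma.js (via graphology) expects a different layout, so the keys would need adapting there.

```python
import json

import networkx as nx


def to_d3_json(graph):
    """Hypothetical helper: dump a typed networkx graph as d3-style node/link JSON."""
    nodes = [{"id": n, "type": attrs.get("type")} for n, attrs in graph.nodes(data=True)]
    # Edge types are (src, relation, dst) tuples; json.dumps writes them as lists
    links = [
        {"source": u, "target": v, "type": attrs.get("type")}
        for u, v, attrs in graph.edges(data=True)
    ]
    return json.dumps({"nodes": nodes, "links": links})


# Small stand-in graph with HeteroData-style "type" attributes
g = nx.DiGraph()
g.add_node(0, type="Station")
g.add_node(1, type="Lot")
g.add_edge(0, 1, type=("Station", "ProcessNow", "Lot"))
payload = to_d3_json(g)
```

The resulting string can be written to a file and fetched by the page that runs the force simulation.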


Solution

  • You can do this with networkx, but you have to tell it explicitly how to color and label the nodes and edges of each type.

    # Simple example of networkx rendering with colored nodes and edges.
    # `data` is assumed to be a HeteroData object with "Station" and "Lot" nodes.
    import matplotlib.pyplot as plt
    import networkx as nx
    from torch_geometric.utils import to_networkx
    
    # For HeteroData input, each node gets a "type" attribute (node type name)
    # and each edge a "type" attribute ((src, relation, dst) tuple).
    graph = to_networkx(data, to_undirected=False)
    
    # Define colors and label prefixes for the node types
    node_type_colors = {
        "Station": "#4599C3",
        "Lot": "#ED8546",
    }
    node_type_prefix = {
        "Station": "S",
        "Lot": "L",
    }
    
    node_colors = []
    labels = {}
    for node, attrs in graph.nodes(data=True):
        node_type = attrs["type"]
        node_colors.append(node_type_colors[node_type])
        # Node ids are global across types, e.g. "S0", "L5"
        labels[node] = f"{node_type_prefix[node_type]}{node}"
    
    # Define colors for the edge types
    edge_type_colors = {
        ("Lot", "SameSetup", "Station"): "#8B4D9E",
        ("Station", "ShortSetup", "Lot"): "#DFB825",
        ("Lot", "SameEnergySetup", "Station"): "#70B349",
        ("Station", "ProcessNow", "Lot"): "#DB5C64",
    }
    
    edge_colors = []
    for from_node, to_node, attrs in graph.edges(data=True):
        color = edge_type_colors[attrs["type"]]
        attrs["color"] = color  # also store the color on the edge itself
        edge_colors.append(color)
    
    # Draw the graph; node_color/edge_color follow graph.nodes/graph.edges order
    pos = nx.spring_layout(graph, k=2)
    nx.draw_networkx(
        graph,
        pos=pos,
        labels=labels,
        with_labels=True,
        node_color=node_colors,
        edge_color=edge_colors,
        node_size=600,
    )
    plt.show()
    

    [Image: Sample Graph]
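The color dictionaries in the answer are written for one particular dataset (Station/Lot and four relations). The sketch below instead assigns a color to whatever node and edge types appear, assuming the same `"type"` attributes. `colors_by_type` is a hypothetical helper, and the default palette is simply matplotlib's `tab10` colors written out as hex strings.

```python
import networkx as nx

# matplotlib's "tab10" colors as hex strings; any list of color strings works
TAB10 = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
         "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]


def colors_by_type(graph, palette=TAB10):
    """Hypothetical helper: pick one color per distinct "type" attribute.

    Returns (node_colors, edge_colors, type_to_color). The lists follow
    graph.nodes / graph.edges order, as nx.draw_networkx expects.
    Colors repeat once there are more types than palette entries.
    """
    type_to_color = {}

    def color_for(t):
        if t not in type_to_color:
            type_to_color[t] = palette[len(type_to_color) % len(palette)]
        return type_to_color[t]

    node_colors = [color_for(attrs["type"]) for _, attrs in graph.nodes(data=True)]
    edge_colors = [color_for(attrs["type"]) for _, _, attrs in graph.edges(data=True)]
    return node_colors, edge_colors, type_to_color


# Small stand-in graph with HeteroData-style "type" attributes
g = nx.DiGraph()
g.add_node(0, type="Station")
g.add_node(1, type="Lot")
g.add_node(2, type="Station")
g.add_edge(0, 1, type=("Station", "ProcessNow", "Lot"))
g.add_edge(1, 2, type=("Lot", "SameSetup", "Station"))
node_colors, edge_colors, legend = colors_by_type(g)
# Drawing then works as in the answer, e.g.:
# nx.draw_networkx(g, pos=nx.spring_layout(g, k=2),
#                  node_color=node_colors, edge_color=edge_colors)
```

The returned mapping can also drive a legend, for example one `matplotlib.patches.Patch(color=c, label=str(t))` per entry, passed to `plt.legend(handles=...)`.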