Hello, what is a good way to visualize a PyG `HeteroData` object (defined as in https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html#creating-heterogeneous-gnns)?
I tried networkx, but I think it is restricted to homogeneous graphs. It is possible to convert the `HeteroData` object to a homogeneous graph first, but the result is much less informative:
# Converting to a homogeneous graph first merges all node and edge types into
# a single graph, which is why the rendering loses the type information.
g = torch_geometric.utils.to_networkx(
    data.to_homogeneous(),
    to_undirected=False,
)
Has anyone tried doing this with another Python library (e.g. matplotlib) or a JavaScript library (sigma.js / d3.js)?
Are there any documentation links you could share?
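You can do this with networkx, but you need a bit of code to tell it how to format the nodes and edges. In recent PyG versions, `torch_geometric.utils.to_networkx` accepts a `HeteroData` object directly. It stores each node's type and each edge's `(source, relation, destination)` type in a `type` attribute, which you can then map to colors and labels: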
# Render a PyG HeteroData graph with networkx: nodes are colored by node type,
# edges by edge type (relation), and a legend maps each color back to its type.
#
# Assumes `data` is a torch_geometric HeteroData whose node/edge types match the
# two color tables below; adjust the tables for your own schema.
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from torch_geometric.utils import to_networkx

# Every node gets a "type" attribute (its node type) and every edge a "type"
# attribute (its (src, relation, dst) tuple). to_multi=True returns a
# MultiDiGraph. A plain DiGraph keeps only one edge per (u, v) pair, so two
# relations between the same nodes (e.g. SameSetup and SameEnergySetup, both
# Lot -> Station) would collapse into a single edge and one type would be lost.
graph = to_networkx(data, to_undirected=False, to_multi=True)

# Fallback color for any node/edge type missing from the tables below.
DEFAULT_COLOR = "#999999"

# Colors and short label prefixes per node type.
node_type_colors = {
    "Station": "#4599C3",
    "Lot": "#ED8546",
}
node_type_prefixes = {
    "Station": "S",
    "Lot": "L",
}

# node_colors follows graph.nodes() order, which is the order nx.draw_networkx
# uses when node_color is a list.
node_colors = []
labels = {}
# networkx node ids are global: each node type is offset by the sizes of the
# preceding types. to_networkx adds the nodes of each type in index order, so
# counting per type recovers the index used by data[node_type].
type_counts = {}
for node, attrs in graph.nodes(data=True):
    node_type = attrs["type"]
    node_colors.append(node_type_colors.get(node_type, DEFAULT_COLOR))
    local_index = type_counts.get(node_type, 0)
    type_counts[node_type] = local_index + 1
    # Unknown types fall back to their full type name as the label prefix.
    prefix = node_type_prefixes.get(node_type, node_type)
    labels[node] = f"{prefix}{local_index}"

# Colors per edge type, keyed by the full (src, relation, dst) tuple.
edge_type_colors = {
    ("Lot", "SameSetup", "Station"): "#8B4D9E",
    ("Station", "ShortSetup", "Lot"): "#DFB825",
    ("Lot", "SameEnergySetup", "Station"): "#70B349",
    ("Station", "ProcessNow", "Lot"): "#DB5C64",
}
# One color per edge, in graph.edges() order, which is the order
# nx.draw_networkx uses when edge_color is a list.
edge_colors = [
    edge_type_colors.get(attrs["type"], DEFAULT_COLOR)
    for _, _, attrs in graph.edges(data=True)
]

# Fix the layout seed so repeated runs produce the same picture.
pos = nx.spring_layout(graph, k=2, seed=42)
nx.draw_networkx(
    graph,
    pos=pos,
    labels=labels,
    with_labels=True,
    node_color=node_colors,
    edge_color=edge_colors,
    node_size=600,
    # Curve the edges so that opposite-direction edges between the same two
    # nodes (e.g. Lot -> Station and Station -> Lot) do not overlap.
    connectionstyle="arc3,rad=0.1",
)

# Legend: filled patches for node types, lines for edge relations.
legend_handles = [
    Patch(color=color, label=type_name)
    for type_name, color in node_type_colors.items()
]
legend_handles += [
    Line2D([0], [0], color=color, label=relation)
    for (_, relation, _), color in edge_type_colors.items()
]
plt.legend(handles=legend_handles, loc="best")
plt.show()