Is there a way to draw the node "features" (-1, 0 and 1 in my example) with networkx?
My current Python code does not draw them. See the code below:
import torch
from torch_geometric.data import Data

# Undirected path 0 - 1 - 2, written as (source, target) pairs in both
# directions and transposed into the [2, num_edges] layout:
# [[0, 1, 1, 2], [1, 0, 2, 1]].
edge_index = torch.tensor(
    [[0, 1], [1, 0], [1, 2], [2, 1]],
    dtype=torch.long,
).t().contiguous()

# Node features: a single value per node, one row per node ([[-1.], [0.], [1.]]).
x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float).unsqueeze(-1)
data = Data(x=x, edge_index=edge_index)

do_visualize = 1
if do_visualize:
    import networkx as nx
    import matplotlib.pyplot as plt

    # Pair the two rows of edge_index column by column: [[0, 1], [1, 0], [1, 2], [2, 1]].
    edge_list = [[src, dst] for src, dst in zip(*edge_index.tolist())]
    G = nx.Graph(edge_list)
    # Make sure every node id 0..N-1 is present, even ones without edges.
    G.add_nodes_from(range(x.shape[0]))

    pos = nx.spring_layout(G, seed=1)  # fixed seed -> same layout every run
    fig = plt.figure(figsize=(6, 6))
    # No `labels=` mapping is passed here, so the features in x are not drawn.
    nx.draw_networkx(
        G,
        pos,
        node_size=800,
        node_color='lightblue',
        edge_color='gray',
        with_labels=True,
    )
    plt.title("Visualization of the Social Network Graph")
    plt.axis('off')
    plt.show()
The networkx nodes don't carry any information about your "features".
A simple method: create a dict `features = {nodename: text, ...}`
and pass it as `labels=features` to `draw_networkx()`,
or use `draw_networkx_labels()`.
features = {
0: "[-1]",
1: "[0]",
2: "[1]",
}
# ...
nx.draw_networkx(G, pos, with_labels=True, ..., labels=features)
# or (with `with_labels=False`)
nx.draw_networkx(G, pos, with_labels=False) # without labels=features
nx.draw_networkx_labels(G, pos, labels=features)
There may be a better way to convert x
to this dictionary;
I only came up with:
features = {node: feat for node, feat in enumerate(x.int().tolist())}
or, with extra text:
features = {node: f"x = {feat}" for node, feat in enumerate(x.int().tolist())}
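It needs `.int()`
because otherwise it creates strings like [1.0].
Note that `.int()` truncates non-integer features (for example, 0.5 becomes 0), so drop it if your features are not whole numbers.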
import torch
from torch_geometric.data import Data

edge_index = torch.tensor(
    [
        [0, 1, 1, 2],
        [1, 0, 2, 1],
    ],
    dtype=torch.long,
)
# Node features: one feature for each node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)

# One label per node id, taken from that node's row of x:
# {0: [-1], 1: [0], 2: [1]}.
# .int() avoids float labels like [1.0], but it truncates non-integer
# features (0.5 -> 0); drop it if your features are not whole numbers.
features = dict(enumerate(x.int().tolist()))
# Alternative with extra text in each label:
# features = {index: f"x = {item}" for index, item in enumerate(x.int().tolist())}
print(features)

do_visualize = True
if do_visualize:
    import networkx as nx
    import matplotlib.pyplot as plt

    edge_list = edge_index.t().tolist()
    G = nx.Graph()
    G.add_edges_from(edge_list)
    G.add_nodes_from(range(x.size(0)))  # also keep nodes that have no edges
    pos = nx.spring_layout(G, seed=1)  # fixed seed for reproducibility

    plt.figure(figsize=(6, 6))
    # labels=features draws the feature values instead of the node ids.
    nx.draw_networkx(
        G, pos,
        node_color='lightblue', edge_color='gray', node_size=800,
        with_labels=True, labels=features,
    )
    # Two-step alternative:
    #   nx.draw_networkx(G, pos, node_color='lightblue', edge_color='gray',
    #                    node_size=800, with_labels=False)
    #   nx.draw_networkx_labels(G, pos, labels=features)
    plt.title("Visualization of the Social Network Graph")
    plt.axis('off')
    plt.show()
Docs: draw_networkx_labels, draw_networkx_edge_labels
See also how to convert `data`
to a networkx graph:
How to visualize a torch_geometric graph in Python? - Stack Overflow
# `from torch_geometric.data import Data` does not bind the name
# `torch_geometric`, so import the converter explicitly.
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
but you still need to add the labels separately.