I'm doing my first graph convolutional neural network project with torch_geometric. I want to visualize the last-layer node embeddings of my model, but I don't know how to get them.
I trained my model on the CiteSeer dataset. You can get the full dataset as easily as this:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root="data/Planetoid", name='CiteSeer', transform=NormalizeFeatures())
My model is a simple two-layer model:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphClassifier(torch.nn.Module):
    def __init__(self, dataset, hidden_dim):
        super(GraphClassifier, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return F.log_softmax(x, dim=1)
If you print my model you will get this:
model = GraphClassifier(dataset, 64)
print(model)
>>>
GraphClassifier(
  (conv1): GCNConv(3703, 64)
  (conv2): GCNConv(64, 6)
)
The model trains successfully; I only want to visualize its last-layer node embeddings. For that I have this function:
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
# emb: (nNodes, hidden_dim)
# node_type: (nNodes,). Entries are torch.int64 ranged from 0 to num_class - 1
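def visualize(emb: torch.Tensor, node_type: torch.Tensor):
    # Project the embeddings to 2-D with t-SNE (needs a NumPy array on the CPU)
    z = TSNE(n_components=2).fit_transform(emb.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    # Colour each node by its class; move the labels to the CPU as well
    plt.scatter(z[:, 0], z[:, 1], s=70, c=node_type.cpu().numpy(), cmap="Set2")
    plt.show()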
I don't know how to extract emb and node_type from my model to pass to the visualize function. emb should be the last-layer node embeddings of the model. How can I get these from my model?
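If you would rather not change forward at all, PyTorch's register_forward_hook can capture a layer's output during an ordinary forward pass. Below is a minimal sketch; the helper name get_last_layer_embeddings and its layer_name parameter are hypothetical. One caveat: the hook sees conv2's raw output, i.e. before the F.relu applied in forward, so apply F.relu to the result yourself if you want the post-ReLU values. The labels used for colouring (node_type) are simply data.y.

```python
import torch


def get_last_layer_embeddings(model, data, layer_name="conv2"):
    """Hypothetical helper: return the output of model.<layer_name> for `data`.

    It registers a temporary forward hook on that layer, runs one forward
    pass without gradients, then removes the hook and restores the mode.
    """
    captured = {}

    def hook(module, inputs, output):
        # Store a detached copy of the layer's output tensor
        captured["emb"] = output.detach()

    layer = getattr(model, layer_name)
    handle = layer.register_forward_hook(hook)
    was_training = model.training
    model.eval()  # disable dropout/batch-norm updates while extracting
    try:
        with torch.no_grad():
            model(data)
    finally:
        handle.remove()  # don't leave the hook attached
        model.train(was_training)
    return captured["emb"]
```

With the trained model, usage would look like: data = dataset[0].to(device), then emb = get_last_layer_embeddings(model, data), then visualize(emb, data.y). Passing layer_name="conv1" would give the 64-dimensional hidden embeddings instead of the 6-dimensional last layer.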
It was solved by changing the model to this:
class GraphClassifier(torch.nn.Module):
    def __init__(self, dataset, hidden_dim):
        super(GraphClassifier, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, dataset.num_classes)

    def forward(self, data, do_visualize=False):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if do_visualize:  # NEW LINE
            visualize(x, data.y)  # NEW LINE
        return F.log_softmax(x, dim=1)
Now, if you call the forward function with do_visualize=True, it plots the last-layer embeddings coloured by the node labels data.y, like this:
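device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = 64
model = GraphClassifier(dataset, hidden_dim)
model.to(device)
# ... train the model here, then plot:
model(dataset[0].to(device), do_visualize=True)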