Hello, I have used GraphSAGE to compute node embeddings. I chose LSTM as the aggregation function, using the PyG library for graph neural networks. The inputs it needs are as follows:
Input 1: Node features (|V|, F_in) - Here I use the x-y coordinates of the nodes in the 2D plane (|V| x 2), already normalized to the range [0, 1], e.g.
x y
0 0.374540 0.598658
1 0.950714 0.156019
2 0.731994 0.155995
Input 2: Edge indices (2, |E|) - I start from the adjacency matrix (|V| x |V|) and keep only its edges, stored in shape (2, |E|). My original adjacency matrix is, e.g.
idx 0 1 2
0 [[0, 1, 1],
1 [1, 0, 1],
2 [1, 1, 0]]
The matrix above has shape (|V| x |V|) and describes 6 edges in the graph: (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1). I had to transform it to match the shape (2, |E|) that PyG uses, and I call the result edge_index:
idx 0 1 2 3 4 5
0 [[0, 0, 1, 1, 2, 2],
1 [1, 2, 0, 2, 0, 1]]
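The conversion from the dense adjacency matrix to the (2, |E|) layout can be sketched in plain PyTorch. This is a minimal sketch, assuming the adjacency matrix is available as a dense tensor; the name adj is chosen here for illustration and does not appear in the original post.

```python
import torch

# Dense adjacency matrix of the 3-node full-mesh example (no self-loops).
# `adj` is an illustrative name, not taken from the original post.
adj = torch.tensor([[0, 1, 1],
                    [1, 0, 1],
                    [1, 1, 0]])

# nonzero() yields one (source, target) pair per row, in row-major order;
# transposing turns that (|E|, 2) tensor into the (2, |E|) layout.
edge_index = adj.nonzero().t().contiguous()

print(edge_index)
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])
```

PyG also ships torch_geometric.utils.dense_to_sparse for this conversion; the plain-PyTorch version above only makes the step explicit.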
Output: Node features (|V|, F_out) - Similar to the node coordinates, but no longer 2D: each node is embedded in a new space with F_out dimensions.
My problem is that the LSTM aggregator forces me to sort edge_index (the edge indices from Input 2); otherwise it raises ValueError: Can not perform aggregation since the 'index' tensor is not sorted.
So I sort it with the following code:
# inside def __init__()
self.graph_sage = SAGEConv(in_channels=2, out_channels=hidden_dim, aggr='lstm')
# inside def forward()
sorted_edge_index, _ = torch.sort(edge_index, dim=1) # for LSTM
x = self.graph_sage(coord.view(-1, 2), sorted_edge_index) # using GraphSAGE
After sorting, the sorted_edge_index tensor looks like this:
idx 0 1 2 3 4 5
0 [[0, 0, 1, 1, 2, 2],
1 [0, 0, 1, 1, 2, 2]]
I noticed that in my full-mesh graph of 3 connected nodes, the sorted edges can be read as (0, 0), (0, 0), (1, 1), (1, 1), (2, 2), (2, 2), which made me curious. I have the following 2 questions:
1. Why does the LSTM aggregator require edge_index to be sorted?
2. If I sort edge_index like this, how will my model know which nodes are connected? All the original edge pairs are gone; it is like feeding the model edge pairs that do not exist in the graph. Will this be a disadvantage?

I have tried the above and it ran fine, but I have some doubts and hope someone knowledgeable can clarify things for a beginner like me. I sincerely hope this question is also useful to other students studying GNNs.
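Use sort_edge_index with sort_by_row=False instead. torch.sort(edge_index, dim=1) sorts each row independently, which destroys the (source, target) pairs; sort_edge_index reorders whole columns, so the edges stay intact and end up ordered by target node, which is what the LSTM aggregator expects: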
from torch_geometric.utils import sort_edge_index
sorted_edge_index = sort_edge_index(edge_index, num_nodes=self.num_nodes, sort_by_row=False)
x = self.graph_sage(coord.view(-1, 2), sorted_edge_index)
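To see why the two kinds of sorting differ, here is a minimal sketch in plain PyTorch, assuming the 3-node full-mesh edge_index from the question. It imitates a sort by target node using a stable sort plus column indexing; this only illustrates the idea and is not PyG's actual implementation of sort_edge_index. The helper name edge_set is hypothetical.

```python
import torch

edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                           [1, 2, 0, 2, 0, 1]])

def edge_set(ei):
    """Return the edges of a (2, |E|) tensor as a set of (source, target) pairs."""
    return set(zip(ei[0].tolist(), ei[1].tolist()))

# Row-wise sort: each row is sorted on its own, so the columns no longer
# correspond to the original (source, target) pairs.
row_sorted, _ = torch.sort(edge_index, dim=1)

# Column permutation: sort the target row once, then reorder whole columns
# with the resulting permutation so every pair stays together.
_, perm = torch.sort(edge_index[1], stable=True)
col_sorted = edge_index[:, perm]

print(edge_set(row_sorted))                          # only self-loops remain
print(edge_set(col_sorted) == edge_set(edge_index))  # True
print(col_sorted[1])                                 # tensor([0, 0, 1, 1, 2, 2])
```

With the row-wise sort the model would only ever see the self-loops (0, 0), (1, 1), (2, 2), so no information would flow between different nodes. With the column permutation the same six edges are passed in, only in a different order.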
https://github.com/pyg-team/pytorch_geometric/discussions/8908