python pytorch neural-network lstm pytorch-geometric

Why does LSTM Aggregation in PyG need to sort edge_index?


Hello, I have used GraphSAGE to compute node embeddings. For the aggregation function I chose LSTM, using PyG (PyTorch Geometric). The layer takes the following inputs and produces the following output:

Input 1: Node features (|V|, F_in). Here I use the x-y coordinates of each node in the 2D plane (shape V x 2), already normalized to the range [0, 1], e.g.:

          x         y
0  0.374540  0.598658
1  0.950714  0.156019
2  0.731994  0.155995

Input 2: Edge indices (2, |E|). My graph starts as a V x V adjacency matrix, from which I extract only the existing edges into a (2, |E|) tensor, e.g.:

idx   0  1  2
0   [[0, 1, 1], 
1    [1, 0, 1], 
2    [1, 1, 0]]

The matrix above has shape (V x V) and contains 6 directed edges (each undirected connection appears twice, once per direction). PyG expects shape (2, |E|), so I convert it into what I call edge_index, with the edges (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1):

idx   0  1  2  3  4  5
0   [[0, 0, 1, 1, 2, 2],
1    [1, 2, 0, 2, 0, 1]]
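One way to do this conversion in plain PyTorch is sketched below, assuming the adjacency matrix is a dense torch tensor: nonzero() lists the (row, col) position of every non-zero entry in row-major order, and transposing gives the (2, |E|) layout. PyG also provides torch_geometric.utils.dense_to_sparse, which returns the edge index together with the edge weights.

```python
import torch

# Dense adjacency matrix of the fully connected 3-node graph from the question.
adj = torch.tensor([[0, 1, 1],
                    [1, 0, 1],
                    [1, 1, 0]])

# nonzero() returns one (row, col) pair per non-zero entry, shape (|E|, 2),
# in row-major order; .t() turns it into PyG's (2, |E|) layout.
edge_index = adj.nonzero().t().contiguous()
print(edge_index)
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])
```

Row 0 holds the source nodes and row 1 the target nodes; column k is the k-th edge.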

Output: Node features (|V|, F_out). As with the input coordinates there is one row per node, but each node is now represented in a new embedding space with F_out dimensions instead of the 2D plane.

My problem is that the LSTM aggregator forces me to sort edge_index (the edge indices from Input 2); otherwise it raises ValueError: Can not perform aggregation since the 'index' tensor is not sorted.
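For context on why this check exists: unlike sum or mean, an LSTM reads a node's incoming messages as an ordered sequence, so the messages have to be packed into a padded tensor of shape (|V|, max_degree, F_in) before the LSTM runs. The block below is a simplified, hypothetical sketch of that packing in plain PyTorch, not PyG's actual source; the sizes (F_out = 4) and the random features are assumptions. It assumes the edges are already grouped by target node (row 1 of edge_index, which PyG's default source-to-target message passing aggregates over), because each message's slot is computed from cumulative per-node counts, and that arithmetic only holds when each node's messages are contiguous.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_nodes, f_in, f_out = 3, 2, 4

# Edges grouped by target node (row 1), pairs kept intact.
edge_index = torch.tensor([[1, 2, 0, 2, 0, 1],
                           [0, 0, 1, 1, 2, 2]])
x = torch.rand(num_nodes, f_in)   # node features, shape (|V|, F_in)

src, dst = edge_index
messages = x[src]                 # one message per edge, shape (|E|, F_in)

# Sortedness check, analogous to the one behind the ValueError.
assert bool((dst[1:] >= dst[:-1]).all()), "index must be sorted"

# Slot of each message inside its target node's block.
# Only valid because dst is sorted (each node's messages are contiguous).
counts = torch.bincount(dst, minlength=num_nodes)
starts = torch.cumsum(counts, 0) - counts
pos = torch.arange(dst.numel()) - starts[dst]

# Padded batch: one sequence of neighbor messages per node (zeros as padding).
max_len = int(counts.max())
dense = x.new_zeros(num_nodes, max_len, f_in)
dense[dst, pos] = messages

# The LSTM runs over each node's neighbor sequence; keep the last time step.
lstm = nn.LSTM(f_in, f_out, batch_first=True)
out, _ = lstm(dense)              # (|V|, max_len, F_out)
aggregated = out[:, -1]           # (|V|, F_out)
print(aggregated.shape)           # torch.Size([3, 4])
```

With an unsorted target row, the same offset arithmetic would compute wrong or out-of-range slots and mix messages between nodes, which is the situation the ValueError guards against.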

So I sort it with the following code:

# inside def __init__()
self.graph_sage=SAGEConv(in_channels=2, out_channels=hidden_dim, aggr='lstm')

# inside def forward()
sorted_edge_index, _ = torch.sort(edge_index, dim=1)  # for LSTM
x = self.graph_sage(coord.view(-1, 2), sorted_edge_index)  # using GraphSAGE

The sorted_edge_index tensor will look like this after sorting.

idx   0  1  2  3  4  5
0   [[0, 0, 1, 1, 2, 2],
1    [0, 0, 1, 1, 2, 2]]
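The snippet below reproduces this sort in plain PyTorch on the 3-node example and compares the edge sets before and after. It shows that torch.sort with dim=1 sorts each row on its own, so the source in row 0 and the target in row 1 of a column no longer come from the same edge.

```python
import torch

edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                           [1, 2, 0, 2, 0, 1]])

# dim=1 sorts row 0 and row 1 independently of each other.
row_sorted, _ = torch.sort(edge_index, dim=1)
print(row_sorted)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 0, 1, 1, 2, 2]])

original_edges = set(map(tuple, edge_index.t().tolist()))
new_edges = set(map(tuple, row_sorted.t().tolist()))
print(sorted(new_edges))           # [(0, 0), (1, 1), (2, 2)]: only self-loops
print(original_edges & new_edges)  # set(): no original edge survives
```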

I noticed that in my fully connected 3-node graph, the sorted tensor reads as the edges (0, 0), (0, 0), (1, 1), (1, 1), (2, 2), (2, 2), which made me curious. I have two questions:

  1. Why does the LSTM aggregator need edge_index to be sorted?
  2. After sorting edge_index like this, how will my model know which nodes are connected? All the original edge pairs are gone, so it looks like I am feeding in edge pairs that don't exist in the graph. Is this a problem?

The code above runs without errors, but I have doubts and hope someone knowledgeable can clarify this for a beginner like me. I also hope this question is useful to other students learning GNNs.


Solution

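  • Sort by destination node with sort_by_row=False. torch.sort(edge_index, dim=1) sorts each row of edge_index independently, which breaks the (source, target) pairs: the result describes a different graph made only of self-loops, so the model no longer sees the real neighbors. What the aggregation needs is for whole edges (columns) to be reordered so that the index it groups by, the target row edge_index[1] under PyG's default source-to-target flow, is sorted. sort_edge_index does this with sort_by_row=False: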

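    from torch_geometric.utils import sort_edge_index
    
    # Reorder whole columns so edge_index[1] (the target nodes the LSTM aggregates over)
    # is sorted; every (source, target) pair stays intact.
    sorted_edge_index = sort_edge_index(edge_index, num_nodes=self.num_nodes, sort_by_row=False)
    x = self.graph_sage(coord.view(-1, 2), sorted_edge_index)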
    

    https://github.com/pyg-team/pytorch_geometric/discussions/8908
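To make the effect of sort_by_row=False concrete, here is a sketch of the same idea in plain PyTorch. The helper sort_by_destination is a hypothetical, illustrative name, not a PyG function: it builds one integer key per edge (target first, then source) and permutes whole columns by that key, so the pairs are never split. In real code, the PyG utility is the one to use.

```python
import torch


def sort_by_destination(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: order edges by target node (row 1), ties broken by
    source node (row 0). Whole columns move together, so pairs stay intact."""
    src, dst = edge_index
    key = dst * num_nodes + src   # unique per distinct (src, dst) pair
    perm = torch.argsort(key)
    return edge_index[:, perm]


edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                           [1, 2, 0, 2, 0, 1]])
sorted_edge_index = sort_by_destination(edge_index, num_nodes=3)
print(sorted_edge_index)
# tensor([[1, 2, 0, 2, 0, 1],
#         [0, 0, 1, 1, 2, 2]])
```

The result still contains exactly the edges (1, 0), (2, 0), (0, 1), (2, 1), (0, 2), (1, 2), and its target row is non-decreasing, which is what the aggregation check requires.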