
Classification using Graph Neural Network


I am working on a fraud detection project using a GNN. My graph has banking codes (SWIFT BIC codes) as nodes, and the edges represent transactions. My tensors have these shapes: node_features_tensor is (210, 6) and adjacency_matrix_tensor is (210, 210).

I have made numerous attempts and am currently following this tutorial: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html

Below is my GNN code:

import torch
import torch.nn as nn

class GCNLayer(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        # Num neighbours = number of incoming edges
        num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)
        node_feats = self.projection(node_feats)
        print("node_feats ",node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbours
        return node_feats

layer = GCNLayer(c_in=6, c_out=210)
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]])
layer.projection.bias.data = torch.Tensor([0., 0.])

with torch.no_grad():
    out_feats = layer(node_features_tensor, adjacency_matrix_tensor)

print("Adjacency matrix", adjacency_matrix_tensor)
print("Input features", node_features_tensor)
print("Output features", out_feats)

No matter what I try, I keep getting dimension errors during the multiplication: "RuntimeError: mat1 and mat2 shapes cannot be multiplied (210x6 and 2x2)".

I understand that node_features_tensor (210, 6) has to be multiplied with adjacency_matrix_tensor (210, 210), but I have been stuck on this for days!
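A quick shape check suggests the error is raised inside the nn.Linear projection rather than by the adjacency product. This is a minimal sketch using random stand-in data with the shapes from the question, not the real transaction graph; it only shows what shapes nn.Linear(6, 210) expects and what happens when they are overwritten with 2x2 values.

```python
import torch
import torch.nn as nn

# Random stand-in features with the question's shape (values are made up).
node_features = torch.rand(210, 6)

projection = nn.Linear(6, 210)
# nn.Linear stores its weight as (out_features, in_features).
print(projection.weight.shape)  # torch.Size([210, 6])
print(projection.bias.shape)    # torch.Size([210])

# Overwrite the parameters with 2x2 / length-2 tensors, as the question's code does.
projection.weight.data = torch.tensor([[1., 0.], [0., 1.]])
projection.bias.data = torch.tensor([0., 0.])

message = None
try:
    # The (210 x 6) input cannot be multiplied by the transposed 2x2 weight.
    projection(node_features)
except RuntimeError as err:
    message = str(err)
    print(message)
```

The failure happens in the very first line of the layer that touches the weights, so the adjacency matrix never gets a chance to participate.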

I have tried multiple GNN/GCN implementations. My goal is to be able to train the model.


Solution

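  • The issue is that you overwrite the parameters of your nn.Linear(6, 210) layer with a 2x2 weight matrix and a length-2 bias. That layer expects a weight of shape (210, 6), i.e. (out_features, in_features), and a bias of shape (210,), so the projection fails before the adjacency matrix is even used. The 2x2 values in the tutorial only fit because its toy example has 2 input and 2 output features. If you want to reproduce the tutorial's identity check with your 6 features, change this to: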

    layer = GCNLayer(c_in=6, c_out=6)
    layer.projection.weight.data = torch.eye(6)
    layer.projection.bias.data = torch.zeros(6)
    

    See the torch.nn.Linear documentation: weight has shape (out_features, in_features) and bias has shape (out_features,).
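One further point: torch.bmm only accepts 3-D (batched) tensors, which the tutorial satisfies by giving its features and adjacency matrix a leading batch dimension. With 2-D tensors of shape (210, 6) and (210, 210), the corrected layer would still fail at torch.bmm. The sketch below adds a batch dimension of size 1 with unsqueeze(0). It assumes random stand-in data (not the real SWIFT/BIC graph) and adds self-loops to the adjacency matrix, as the tutorial's example matrix has, so no node divides by zero neighbours.

```python
import torch
import torch.nn as nn


class GCNLayer(nn.Module):
    """GCN layer as in the tutorial: project features, sum over neighbours, average."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        # node_feats: (batch, num_nodes, c_in)
        # adj_matrix: (batch, num_nodes, num_nodes)
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        return node_feats / num_neighbours


num_nodes, num_feats = 210, 6
# Hypothetical random data standing in for the real transaction graph.
node_features_tensor = torch.rand(num_nodes, num_feats)
adjacency_matrix_tensor = (torch.rand(num_nodes, num_nodes) > 0.95).float()
# Self-loops guarantee every node has at least one neighbour.
adjacency_matrix_tensor.fill_diagonal_(1.0)

layer = GCNLayer(c_in=num_feats, c_out=num_feats)
layer.projection.weight.data = torch.eye(num_feats)
layer.projection.bias.data = torch.zeros(num_feats)

with torch.no_grad():
    # torch.bmm needs 3-D tensors, so add a batch dimension of size 1.
    out_feats = layer(node_features_tensor.unsqueeze(0),
                      adjacency_matrix_tensor.unsqueeze(0))

print(out_feats.shape)  # torch.Size([1, 210, 6])
```

With the identity projection, each output row is simply the mean of the node's neighbours' input features, which makes the result easy to verify by hand. For actual training, drop the two .data assignments so the layer keeps its learnable initialization; c_out can then be any width you like.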