Tags: python-3.x, pytorch, pytorch-geometric, graph-neural-network

PyTorch Geometric custom layer parameters not updating


I am developing a graph neural network using PyTorch Geometric. The idea is to start with multivariate time series, build a graph based on the correlation between those time series, and then classify the graph. I have built a CorrelationLayer that computes the adjacency matrix of the graph using the Pearson correlation coefficient and multiplies it element-wise by a matrix of trainable weights. This matrix is then passed, along with the time series as node features, to a graph convolution layer (I will add other layers for classification after the graph convolution, but I made a super-simplified version for this question). The problem is that when I try to train the model, the weights of the correlation layer do not update, while the parameters of the graph convolution layer update without any problem.

Here is the code for the correlation layer:

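import torch
import torch.nn as nn
from scipy.stats import pearsonr  # assumed source of pearsonr; returns (coefficient, p-value)

class CorrelationLayer(nn.Module):

    def __init__(self, num_time_series):
        super().__init__()
        self.num_time_series = num_time_series
        # One trainable weight for every pair of time series
        self.weights = nn.Parameter(torch.rand((num_time_series, num_time_series)))

    def forward(self, x):
        # x: (num_time_series, ts_length); the diagonal of correlations stays 0
        correlations = torch.zeros((x.shape[0], x.shape[0]))
        for i in range(x.shape[0]):
            for j in range(i+1, x.shape[0]):
                c, _ = pearsonr(x[i], x[j])
                correlations[i, j] = c
                correlations[j, i] = c
        correlations = correlations * self.weights
        return correlations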
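
As an aside, the double loop over pairs of series can be replaced by a single vectorized call. The following is only a minimal sketch, assuming `x` has shape (num_time_series, ts_length) and that `torch.corrcoef` is available in the installed PyTorch release; the helper name `correlation_matrix` is made up for illustration and is not part of the original code. The diagonal is zeroed to match the loop-based version.

```python
import torch

def correlation_matrix(x: torch.Tensor) -> torch.Tensor:
    """Pearson correlation between the rows of x, with the diagonal set to 0."""
    # torch.corrcoef treats each row as a variable and each column as an observation
    corr = torch.corrcoef(x)
    # Mask out the diagonal, which the loop-based version leaves at 0
    off_diagonal = 1.0 - torch.eye(x.shape[0], dtype=x.dtype, device=x.device)
    return corr * off_diagonal

# First sample from the question's data
x = torch.tensor([
    [0., 1., 2., 3.],
    [1., 2., 3., 4.],
    [0., 6., 3., 1.],
    [3., 2., 1., 0.],
])
correlations = correlation_matrix(x)
```

Inside `CorrelationLayer.forward`, the result could then be multiplied by `self.weights` exactly as before.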

And here is the code for the GCN model:

import torch_geometric
from torch_geometric.nn import GCNConv

class GCN(nn.Module):

    def __init__(self, num_time_series, ts_length, hidden_channels):
        super().__init__()
        self.corr_layer = CorrelationLayer(num_time_series)
        self.graph_conv = GCNConv(ts_length, hidden_channels)

    def forward(self, x):
        adj = self.corr_layer(x)
        # Only the edge indices ([0]) are passed on to the convolution
        out = self.graph_conv(x, torch_geometric.utils.dense_to_sparse(adj)[0])
        return out

This is the code that I wrote to test the model with some sample data:

def train(model, X_train, Y_train):
    model.train()
    for x, y in zip(X_train,Y_train):
        out = model(x)
        print(model.corr_layer.weights)
        print(model.graph_conv.state_dict().values())
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


X =  torch.tensor([
    [
        [0.,1.,2.,3.],
        [1.,2.,3.,4.],
        [0.,6.,3.,1.],
        [3.,2.,1.,0.]
    ],
    [
        [2.,4.,6.,8.],
        [1.,2.,3.,4.],
        [1.,8.,3.,7.],
        [3.,2.,1.,0.]
    ],
    [
        [0.,1.,2.,3.],
        [1.,2.,3.,4.],
        [0.,6.,3.,1.],
        [3.,2.,1.,0.]
    ]
])

Y = torch.tensor([
    [[1.],[1.],[1.],[1.]],
    [[0.],[0.],[0.],[0.]],
    [[1.],[1.],[1.],[1.]]
])

model = GCN(4,4,1)

optimizer = torch.optim.Adam(model.parameters(), lr=0.5)
criterion = torch.nn.MSELoss()

for epoch in range(1, 100):
    train(model, X,Y)

The prints in the train function show that the parameters of the graph_conv layer are updating, while the weights of the correlation layer are not.
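
A more direct check than printing the values is to look at each parameter's `.grad` after `loss.backward()`: a parameter whose gradient is still `None` never took part in computing the loss. The sketch below uses a made-up `Toy` module (not the model from the question) in which one parameter is deliberately left out of the forward pass.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    # Hypothetical module: `unused` never enters the forward computation
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)
        self.unused = nn.Parameter(torch.rand(4, 4))

    def forward(self, x):
        return self.used(x)

model = Toy()
loss = model(torch.ones(2, 4)).sum()
loss.backward()

# Parameters whose .grad is still None were not part of the computation graph
missing = [name for name, p in model.named_parameters() if p.grad is None]
print(missing)
```

Running the same loop over `model.named_parameters()` for the GCN model after `loss.backward()` should show whether `corr_layer.weights` receives a gradient at all.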

At the moment my guess is that the problem is in the transition from the adjacency matrix to the sparse version with dense_to_sparse but I am not sure.

Has anyone experienced something similar, or does anyone have any ideas or suggestions?


Solution

  • Well, even though it's a very pointed and specific question, for anyone passing through here in the future, here's the solution:

    As pointed out by the user thecho7 on the PyTorch forum (https://discuss.pytorch.org/t/pytorch-geometric-custom-layer-parameters-not-updating/170632/2):

    dense_to_sparse returns two tensors: the first holds the indices of the nonzero elements (the edge index) and the second holds their values. The index tensor does not carry a gradient, whereas the value tensor does. Because only the index tensor was passed to graph_conv, the weights of the correlation layer never entered the computation of the loss.

    So in the forward method I changed

    out = self.graph_conv(x, torch_geometric.utils.dense_to_sparse(adj)[0])
    

    to

    edge_index, edge_weight = torch_geometric.utils.dense_to_sparse(adj)
    out = self.graph_conv(x, edge_index, edge_weight)
    

    and now the weights of the correlation layer update.
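
The effect can be reproduced without PyTorch Geometric. The sketch below is an illustration under assumptions, not the library's code: the `nonzero` plus indexing step is a torch-only stand-in for what dense_to_sparse returns, and the all-ones matrix with a zero diagonal is a placeholder for the correlation matrix.

```python
import torch

torch.manual_seed(0)

# Trainable weights, as in CorrelationLayer (4 time series assumed)
weights = torch.nn.Parameter(torch.rand(4, 4))

# Constant placeholder for the correlation matrix (symmetric, zero diagonal)
correlations = torch.ones(4, 4) - torch.eye(4)
adj = correlations * weights

# Torch-only stand-in for dense_to_sparse: indices and values of the nonzero entries
edge_index = adj.nonzero().t()                    # integer tensor, carries no gradient
edge_weight = adj[edge_index[0], edge_index[1]]   # float tensor, tracks the gradient

index_has_grad = edge_index.requires_grad
value_has_grad = edge_weight.requires_grad

# Only a loss that uses edge_weight can send a gradient back to `weights`
edge_weight.sum().backward()
grad_reaches_weights = weights.grad is not None
```

A loss built from `edge_index` alone has no path back to `weights`, which is consistent with the behaviour described in the question.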