When I use this class to build a GATConv network to predict the class of a graph I get a dimension error. How would I resolve this.
class GraphClassifier(nn.Module):
def __init__(self, in_feats, hidden_size, num_classes):
super(GraphClassifier, self).__init__()
self.conv1 = GATv2Conv(in_feats, hidden_size, num_heads=4)
self.conv2 = GATv2Conv(4*hidden_size, hidden_size, num_heads=4)
self.conv3 = GATv2Conv(4*hidden_size, hidden_size, num_heads=4)
self.conv4 = GATv2Conv(4*hidden_size, hidden_size, num_heads=1)
self.classify = nn.Linear(hidden_size, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self, g, inputs):
h = self.conv1(g, inputs)
h = F.elu(h)
h = self.dropout(h)
h = self.conv2(g, h)
h = F.elu(h)
h = self.dropout(h)
h = self.conv3(g, h)
h = F.elu(h)
h = self.dropout(h)
h = self.conv4(g, h)
h = F.elu(h)
h = self.dropout(h)
with g.local_scope():
g.ndata['h'] = h
# Calculate graph representation by max pooling readout.
hg = dgl.max_nodes(g, 'h')
return self.classify(hg)
This is the code used for training:
import torch.nn.functional as F
model = GraphClassifier(dataset.dim_nfeats, 16, dataset.gclasses)
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
# Instantiate a predefined optimizer from torch.optim - this
# is the method that will be used to perform Gradient Descent.
# We recommend using Adam.
for epoch in range(400):
model.train() # Set the model to train mode to allow gradients to be calculated
cumulative_loss_train = 0.0 # Use this accumulate loss over an epoch and print it out at the end
# to get an estimate for how well your model is performing
for batched_graph, labels in trainloader:
features = batched_graph.ndata['attr']
logits = model(batched_graph, batched_graph.ndata["attr"])
loss = F.cross_entropy(logits, labels) # Compute cross entropy loss.
opt.zero_grad() # Reset gradients for the next batch, since they accumulate by default
loss.backward() # Backprop
opt.step() # Update parameters
cumulative_loss_train += loss.item()
This is the error message: RuntimeError: mat1 and mat2 shapes cannot be multiplied (7140x16 and 64x64) The error refers to this line: h = self.conv2(g, h) in the forward function.
After h = self.conv1(g, inputs)
, the dims of h is: [node_num, num_heads, hidden_size]
, so you can reshape h by
h = h.reshape(h.shape[0], -1)
h = self.conv2(g, h)