Tags: python, pytorch, gabor-filter

Runtime error: mat1 dim 1 must match mat2 dim 0


I am running a classification program using GaborNet. Part of my code is:

class module(nn.Module):
        def __init__(self):
            super(module, self).__init__()
            self.g0 = modConv2d(in_channels=3, out_channels=32, kernel_size=(11, 11), stride=1)
            self.c1 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=(2, 2),stride=1)
            self.c2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(2, 2),stride=1)

        def forward(self, x):
            # ... convolution and pooling layers ...
            #x = x.view(x.size(0), -1)
            #x = x.view(1, *x.shape)
            #x = x.view(-1, 512*12*12)

            x = F.relu(self.fc1(x))
            print(x.shape)
            x = F.relu(self.fc2(x))
            print(x.shape)
            x = self.fc3(x)
            return x 

I am getting this error at this line: x = F.relu(self.fc1(x))

and the error is : RuntimeError: mat1 dim 1 must match mat2 dim 0

However, the tensor shapes through the layers up to fc1 are:

torch.Size([64, 3, 150, 150])
torch.Size([64, 32, 140, 140])
torch.Size([64, 32, 70, 70])
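(For reference, a valid, unpadded convolution with kernel size k and stride 1 maps a spatial dimension of size n to n - k + 1, and a stride-2 pool halves it, which accounts for the 150 → 140 → 70 progression above. A small sketch of that arithmetic; the 2×2 pooling step is an assumption, since the full forward pass is not shown:)

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of one spatial dimension after a Conv2d / MaxPool2d layer."""
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(150, kernel=11)          # g0 (11x11 conv): 150 -> 140
print(s)                              # 140
s = conv_out(s, kernel=2, stride=2)   # assumed 2x2 max-pool: 140 -> 70
print(s)                              # 70
```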

Solution

  • You were on the right track: you indeed need to reshape your data just after the convolution layers and before proceeding to the fully connected layers.

    The safest way to flatten your tensor is nn.Flatten; a hand-written view can otherwise end up disrupting the batch dimension. The last spatial layer outputs a shape of (64, 128, 3, 3); once flattened, this tensor has shape (64, 1152), where 1152 = 128*3*3. Therefore your first fully connected layer should have 1152 input features.

    Something like this should work:

    class GaborNN(nn.Module):
        def __init__(self):
            super().__init__()
            ...
            self.fc1 = nn.Linear(in_features=1152, out_features=128)
            self.fc2 = nn.Linear(in_features=128, out_features=128)
            self.fc3 = nn.Linear(in_features=128, out_features=7)
            self.flatten = nn.Flatten()
    
        def forward(self, x):
            ...
            x = self.flatten(x)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
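To double-check the 1152 figure without hand-computing shapes, one can flatten a dummy tensor of the reported shape (a minimal sketch, assuming the last spatial layer really does output (64, 128, 3, 3) as stated above):

```python
import torch
import torch.nn as nn

# Dummy tensor with the shape the last spatial layer reportedly outputs.
x = torch.zeros(64, 128, 3, 3)

flatten = nn.Flatten()   # keeps dim 0 (the batch), flattens the rest
out = flatten(x)
print(out.shape)         # torch.Size([64, 1152])

# Equivalent manual reshape that also preserves the batch dimension:
out2 = x.view(x.size(0), -1)
assert out2.shape == out.shape
```

Reading the in_features value off such a dummy forward pass is also a quick way to adapt fc1 if the input resolution or the convolutional stack changes.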