I have the following U-Net architecture that is causing problems:
```
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder1 = self.double_conv(in_channels, 64)
        self.encoder2 = self.down(64, 128)
        self.encoder3 = self.down(128, 256)
        self.encoder4 = self.down(256, 512)
        self.bottleneck = self.double_conv(512, 1024)
        self.decoder4 = self.up(1024, 512)
        self.decoder3 = self.up(512, 256)
        self.decoder2 = self.up(256, 128)
        self.decoder1 = self.up(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)  # SAME convolution/padding

    def double_conv(self, in_channels, out_channels):  # conv block
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def up(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)  # Output: [1, 64, 256, 256]
        print("enc1.shape", enc1.shape)
        enc2 = self.encoder2(enc1)  # Output: [1, 128, 128, 128]
        print("enc2.shape", enc2.shape)
        enc3 = self.encoder3(enc2)  # Output: [1, 256, 64, 64]
        print("enc3.shape", enc3.shape)
        enc4 = self.encoder4(enc3)  # Output: [1, 512, 32, 32]
        print("enc4.shape", enc4.shape)
        bottleneck_output = self.bottleneck(enc4)  # Output: [1, 1024, 32, 32]
        print("bottleneck_output", bottleneck_output.shape)

        # Decoder
        dec4 = self.decoder4(bottleneck_output)  # Output: [1, 512, 64, 64]
        print(dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)  # skip connection, concatenate: [1, 1024, 64, 64]
        dec4 = self.double_conv(1024, 512)(dec4)  # Corrected input channels to 1024
        dec3 = self.decoder3(dec4)  # Output: [1, 256, 128, 128]
        dec3 = torch.cat((dec3, enc3), dim=1)  # Concatenate: [1, 512, 128, 128]
        dec3 = self.double_conv(512, 256)(dec3)  # Corrected input channels to 512
        dec2 = self.decoder2(dec3)  # Output: [1, 128, 256, 256]
        dec2 = torch.cat((dec2, enc2), dim=1)  # Concatenate: [1, 256, 256, 256]
        dec2 = self.double_conv(256, 128)(dec2)  # Corrected input channels to 256
        dec1 = self.decoder1(dec2)  # Output: [1, 64, 512, 512]
        dec1 = torch.cat((dec1, enc1), dim=1)  # Concatenate: [1, 128, 512, 512]
        dec1 = self.double_conv(128, 64)(dec1)  # Corrected input channels to 128
        return self.final_conv(dec1)  # Output: [1, 1, 512, 512]
```
When executing it in a main method via

```
unet = UNet(in_channels=1, out_channels=1)
sample_input = torch.randn(1, 1, 256, 256)
output = unet(sample_input)
```
I get:

```
enc1.shape torch.Size([1, 64, 256, 256])
enc2.shape torch.Size([1, 128, 128, 128])
enc3.shape torch.Size([1, 256, 64, 64])
enc4.shape torch.Size([1, 512, 32, 32])
bottleneck_output torch.Size([1, 1024, 32, 32])
```
and the following error:

```
---> 55 dec4 = self.decoder4(bottleneck_output)
RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels instead
```
So the problem is apparently the shape of `bottleneck_output`, which does have 1024 channels, but `decoder4` does not seem to accept it. I tried matching the dimensions, as well as other things like an align function, but nothing has worked so far. Printing the output shapes didn't really help either. Thanks for any hints.
Your problem is with the definition of the `up` method:
```
def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(in_channels, out_channels),
    )
```
`nn.ConvTranspose2d` outputs a tensor with `out_channels` channels, but `double_conv` expects an input tensor with `in_channels` channels.
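You can reproduce the mismatch in isolation with the `decoder4` numbers, i.e. `up(1024, 512)` (a minimal sketch; the variable names are just for illustration):

```
import torch
import torch.nn as nn

# First layer of up(1024, 512): produces 512 channels.
upsample = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# First layer of double_conv(1024, 512): still expects 1024 channels.
conv = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

x = torch.randn(1, 1024, 32, 32)  # same shape as the bottleneck output
y = upsample(x)                   # -> [1, 512, 64, 64]
conv(y)                           # raises the RuntimeError from the question
```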
You should probably use something like:
```
def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(out_channels, out_channels),  # NOTE CHANGE HERE
    )
```
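As a quick sanity check (a sketch that copies your `double_conv` as a standalone function), the corrected block now accepts the 1024-channel bottleneck output and produces the expected `[1, 512, 64, 64]`:

```
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

# Corrected decoder4 = up(1024, 512): double_conv now takes the 512 channels
# that ConvTranspose2d actually produces.
decoder4 = nn.Sequential(
    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
    double_conv(512, 512),
)

bottleneck_output = torch.randn(1, 1024, 32, 32)
print(decoder4(bottleneck_output).shape)  # torch.Size([1, 512, 64, 64])
```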