In the decoder part of the UNet architecture, an upsampling layer is often followed by a Conv2d. Here is an example:
import torch.nn as nn

class UpConv(nn.Module):
    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_chans, out_chans, kernel_size=1),
        )
        self.conv = DoubleConv(out_chans * 2, out_chans)  # DoubleConv is defined elsewhere
    ...
Could someone please explain why we need the Conv2d?
I read that the Conv2d layer allows the model to learn spatial hierarchies in the upsampled feature map. But I don't understand why that is relevant: 2D convolutions have already been applied during encoding, so the spatial information has already been learnt.
My guess is that Conv2d is used to adjust the number of output channels. Is this right?
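To make the channel bookkeeping concrete, here is a minimal runnable sketch. The question elides both DoubleConv and forward, so both are assumptions here: DoubleConv is a hypothetical stand-in (two 3x3 conv, BatchNorm, ReLU blocks, as in many UNet implementations), and forward is assumed to concatenate the upsampled tensor with an encoder skip connection along the channel dimension. The tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Hypothetical stand-in for the question's DoubleConv (assumed structure):
    two 3x3 conv -> BatchNorm -> ReLU blocks."""

    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_chans),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UpConv(nn.Module):
    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.up = nn.Sequential(
            # Fixed bilinear interpolation: doubles H and W, keeps the channel count
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            # 1x1 conv: learned per-pixel projection from in_chans to out_chans
            nn.Conv2d(in_chans, out_chans, kernel_size=1),
        )
        self.conv = DoubleConv(out_chans * 2, out_chans)

    def forward(self, x, skip):
        # Assumed forward: concatenate with the encoder skip features on dim 1
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # out_chans + out_chans channels
        return self.conv(x)


decoder_in = torch.randn(1, 128, 16, 16)  # deeper decoder features
skip = torch.randn(1, 64, 32, 32)         # encoder features at the higher resolution
block = UpConv(128, 64)
up_out = block.up(decoder_in)
out = block(decoder_in, skip)
print(up_out.shape)  # torch.Size([1, 64, 32, 32])
print(out.shape)     # torch.Size([1, 64, 32, 32])
```

Under these assumptions, dropping the 1x1 conv would leave the upsampled tensor with 128 channels, so the concatenation would produce 192 channels instead of the out_chans * 2 = 128 that DoubleConv is built for.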
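It is because nn.Upsample() has no learnable weights: it upsamples the feature map with a fixed rule. In your code, 'bilinear' mode is used, which is an interpolation (resampling) technique, so its output is fully determined by its input.

The nn.Conv2d() layer is added after Upsample() to give this stage of the decoder learnable weights. Without it, a sequence of nn.Upsample() calls would not let the model learn anything during upsampling; it would just apply the same interpolation every time.

Your guess is also correct. Because kernel_size=1, this convolution looks at one pixel at a time, so it does not learn spatial patterns by itself. What it learns is a per-pixel mixing of channels that maps in_chans to out_chans. That matters here because DoubleConv expects out_chans*2 input channels, which is presumably the upsampled map concatenated with an out_chans-channel skip connection from the encoder.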
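One way to check this is to count trainable parameters and to test what a 1x1 kernel can and cannot see. This is a sketch with arbitrary channel counts (128 and 64); the 3x3 convolution is included only for comparison and is not part of the question's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_chans, out_chans = 128, 64

upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
conv1x1 = nn.Conv2d(in_chans, out_chans, kernel_size=1)
conv3x3 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1)  # comparison only


def n_params(module):
    # Count trainable parameters only
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


print(n_params(upsample))  # 0: interpolation is a fixed rule
print(n_params(conv1x1))   # 128*64 + 64 = 8256
print(n_params(conv3x3))   # 128*64*9 + 64 = 73792

# Perturb every channel at pixel (0, 0) and see which output pixels change.
x = torch.randn(1, in_chans, 8, 8)
x2 = x.clone()
x2[..., 0, 0] += 1.0
with torch.no_grad():
    # Small tolerance guards against floating-point noise from the conv backend
    changed_1x1 = (conv1x1(x2) - conv1x1(x)).abs().sum(dim=1) > 1e-4
    changed_3x3 = (conv3x3(x2) - conv3x3(x)).abs().sum(dim=1) > 1e-4
print(changed_1x1.sum().item())  # 1: only the perturbed pixel
print(changed_3x3.sum().item())  # 4: the corner pixel and its 3 neighbours
```

So the 1x1 conv after Upsample learns how to recombine channels at each location, while the spatial interpolation itself stays fixed bilinear; any learned spatial refinement of the upsampled map has to come from larger kernels later in the block.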