Tags: pytorch, initialization

How to decide which mode to use for 'kaiming_normal' initialization


I have read several codebases that do layer initialization using PyTorch's nn.init.kaiming_normal_(). Some use the fan_in mode, which is the default. One of the many examples can be found here and is shown below.

init.kaiming_normal(m.weight.data, a=0, mode='fan_in')

However, I sometimes see people using the fan_out mode, as seen here and shown below.

if isinstance(m, nn.Conv2d):
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

Can someone give me some guidelines or tips to help me decide which mode to select? Furthermore, I am working on image super-resolution and denoising tasks in PyTorch; which mode would be more beneficial for those?


Solution

  • According to the documentation:

    Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass.

    and according to Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015):

    We note that it is sufficient to use either Eqn.(14) or Eqn.(10)

    where Eqn.(10) and Eqn.(14) are fan_in and fan_out respectively. Furthermore:

    This means that if the initialization properly scales the backward signal, then this is also the case for the forward signal; and vice versa. For all models in this paper, both forms can make them converge

    so, all in all, it doesn't matter much; it's more about what you are after. I assume that if you suspect your backward pass might be more "chaotic" (greater variance), it is worth changing the mode to fan_out. This might happen when the loss oscillates a lot (e.g. very easy examples followed by very hard ones). A small check of what each mode preserves is sketched below.

    The correct choice of nonlinearity is more important, where nonlinearity is the activation you are using after the layer you are currently initializing. Current defaults set it to leaky_relu with a=0, which is effectively the same as relu. If you are using leaky_relu, you should change a to its negative slope (see the second sketch below).
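
    Here is a minimal sketch (my own, not from the documentation or the paper) of that variance check. It assumes a toy stack of Linear + ReLU layers whose width shrinks at each step; with square layers fan_in equals fan_out and the two modes would coincide.

        import torch
        import torch.nn as nn

        widths = [1024, 512, 256, 128, 64, 32]

        def build(mode):
            layers = []
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                lin = nn.Linear(n_in, n_out, bias=False)
                nn.init.kaiming_normal_(lin.weight, mode=mode, nonlinearity='relu')
                layers += [lin, nn.ReLU()]
            return nn.Sequential(*layers)

        for mode in ('fan_in', 'fan_out'):
            torch.manual_seed(0)
            net = build(mode)
            x = torch.randn(4096, widths[0], requires_grad=True)
            out = net(x)
            out.backward(torch.randn_like(out))
            # 'fan_in' roughly preserves the magnitude of the forward signal,
            # 'fan_out' roughly preserves the magnitude of the backward signal
            print(mode, out.var().item(), x.grad.var().item())

    With the shrinking widths above you should see, roughly, that fan_in keeps the output variance closer to the input's while the gradients shrink, and fan_out keeps the input-gradient variance closer to that of the gradient fed into backward() while the activations grow.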
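    And a second sketch (again my own; the layer sizes and the 0.2 slope are just placeholders) of how you might wire the nonlinearity and a arguments to the activation actually used after each layer:

        import torch.nn as nn

        slope = 0.2
        model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.LeakyReLU(slope),
            nn.Conv2d(64, 64, 3, padding=1), nn.LeakyReLU(slope),
        )

        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                # each conv here is followed by LeakyReLU, so pass its negative slope as `a`;
                # for a plain ReLU you would pass nonlinearity='relu' (then `a` is ignored)
                nn.init.kaiming_normal_(m.weight, a=slope, mode='fan_in',
                                        nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)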