nlp, pytorch, bert-language-model, transformer-model, language-model

BERT: Weights of input embeddings as part of the Masked Language Model


I looked through different implementations of BERT's Masked Language Model. For pre-training there are two common versions:

  1. The decoder simply takes the final embedding of the [MASK]ed token and passes it through a linear layer (without any modifications):
    class LMPrediction(nn.Module):
        def __init__(self, hidden_size, vocab_size):
            super().__init__()
            self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
            self.bias = nn.Parameter(torch.zeros(vocab_size))
            self.decoder.bias = self.bias
        def forward(self, x):
            return self.decoder(x)
  2. Some implementations use the weights of the input embeddings as the weights of the decoder linear layer:
    class LMPrediction(nn.Module):
        def __init__(self, hidden_size, vocab_size, embeddings):
            super().__init__()
            self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
            self.bias = nn.Parameter(torch.zeros(vocab_size))
            self.decoder.weight = embeddings.weight ## <- THIS LINE
            self.decoder.bias = self.bias
        def forward(self, x):
            return self.decoder(x)

Which one is correct? Mostly I see the first implementation. However, the second one makes sense as well, but I cannot find it mentioned in any papers (I would like to know whether the second version is somehow superior to the first one).
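For context, here is a minimal usage sketch of the tied variant (the second LMPrediction class above). The layer sizes and the plain nn.Embedding standing in for BERT's token embedding table are placeholders for illustration, not values taken from any particular implementation:

    import torch
    import torch.nn as nn

    hidden_size, vocab_size = 768, 30522   # illustrative, BERT-base-like sizes

    # Input embedding table, standing in for BERT's token embeddings.
    embeddings = nn.Embedding(vocab_size, hidden_size)

    # Tied decoder: the second LMPrediction variant defined above.
    lm_head = LMPrediction(hidden_size, vocab_size, embeddings)

    # The decoder and the embedding share the same Parameter object, so no
    # second (vocab_size x hidden_size) matrix is allocated, and gradients
    # from the MLM loss flow back into the embedding table.
    assert lm_head.decoder.weight is embeddings.weight

    hidden_states = torch.randn(2, 8, hidden_size)  # (batch, seq_len, hidden_size)
    logits = lm_head(hidden_states)                 # (batch, seq_len, vocab_size)
    print(logits.shape)                             # torch.Size([2, 8, 30522])

Because the decoder reuses the embedding Parameter, the vocab_size x hidden_size matrix is stored only once, and the MLM loss updates the input embeddings directly.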


Solution

  • For those who are interested, it is called weight tying or joint input-output embedding. There are two papers that argue for the benefit of this approach: