python-3.x, neural-network, pytorch, transformer-model, seq2seq

PyTorch: Different Forward Methods for Train and Test/Validation


I'm currently trying to extend a model that is based on FairSeq/PyTorch. During training I need to train two encoders: one with the target sample, and the original one with the source sample.

So the current forward function looks like this:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

Based on this idea, I want something like this:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatenation_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out

Is there any way to do this?

Edit: These are the constraints that I have, since I need to extend FairseqEncoderDecoderModel:

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 

Edit 2: The parameters passed to the forward function in Fairseq can be altered by implementing your own Criterion, see for example CrossEntropyCriterion, where sample['net_input'] is passed to the __call__ function of the model, which invokes the forward method.
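This dispatch can be sketched without the rest of Fairseq: a criterion calls `model(**sample['net_input'])`, so any key you put into `net_input` arrives as a keyword argument of `forward`. In the toy model below, only the `sample['net_input']` convention comes from Fairseq; the model itself and the `tgt_tokens` key are purely illustrative.

```python
import torch


class ToyModel(torch.nn.Module):
    """Illustrative model; only the keyword names mimic Fairseq's net_input."""

    def forward(self, src_tokens=None, src_lengths=None,
                prev_output_tokens=None, tgt_tokens=None):
        # A custom criterion that adds `tgt_tokens` to net_input
        # makes it arrive here as a regular keyword argument.
        return {"received_tgt": tgt_tokens is not None}


model = ToyModel()
sample = {"net_input": {
    "src_tokens": torch.zeros(1, 3, dtype=torch.long),
    "src_lengths": torch.tensor([3]),
    "prev_output_tokens": torch.zeros(1, 3, dtype=torch.long),
    "tgt_tokens": torch.zeros(1, 3, dtype=torch.long),
}}
# This unpacking is exactly what a criterion's forward does with the sample.
net_output = model(**sample["net_input"])
```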


Solution

  • First of all, you should always define forward, not some other custom method that you call on the torch.nn.Module instance.

    Definitely do not overload eval() as shown by trsvchn, as it is the evaluation method defined by PyTorch (see here). This method puts the layers inside your model into evaluation mode (e.g. specific behaviour changes such as Dropout or BatchNorm switching to inference mode).

    Furthermore, you should invoke the model through the __call__ magic method (i.e. model(inputs) rather than model.forward(inputs)). Why? Because hooks and other PyTorch-specific machinery are run properly that way.
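The hook difference is easy to demonstrate; a minimal sketch (the module and hook below are made up for illustration):

```python
import torch


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2


net = Doubler()
fired = []
# Forward hooks run only when the module is invoked via __call__.
net.register_forward_hook(lambda module, inputs, output: fired.append(1))

net(torch.tensor(1.0))          # __call__: hook fires
net.forward(torch.tensor(1.0))  # direct call: hook is bypassed
```

After both calls, the hook has fired exactly once, which is why bypassing __call__ silently breaks anything that relies on hooks.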

    Secondly, do not use some external mode string variable as suggested by @Anant Mittal. That is what the training attribute in PyTorch is for: it is the standard way to differentiate whether the model is in eval or train mode (it is toggled by model.train() and model.eval()).

    That being said, you are best off doing it like this:

    import torch
    
    
    class Network(torch.nn.Module):
        def __init__(self):
            super().__init__()
            ...
    
        # You could split it into two functions but both should be called by forward
        def forward(
            self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
            tgt_tokens=None, **kwargs
        ):
            encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
            if not self.training:
                return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
            # Training: run the second encoder on the target sample and concatenate
            autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
            concat = some_concatenation_func(encoder_out, autoencoder_out)
            return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    

    You could (and arguably should) split the above into two separate methods, but it's not too bad as-is since the function is rather short and readable. Just stick to PyTorch's way of handling things where easily possible, rather than ad-hoc solutions. And no, there will be no problem with backpropagation; why would there be one?
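For completeness, here is a minimal, self-contained sketch of that split-into-two-methods variant driven by self.training (the toy arithmetic stands in for real encoder/decoder work):

```python
import torch


class ToyNet(torch.nn.Module):
    """Toy stand-in: the +1/-1 arithmetic replaces real encoder/decoder logic."""

    def forward(self, x):
        # Dispatch on the flag toggled by model.train() / model.eval().
        if self.training:
            return self._forward_train(x)
        return self._forward_eval(x)

    def _forward_train(self, x):
        return x + 1.0

    def _forward_eval(self, x):
        return x - 1.0


net = ToyNet()
x = torch.tensor(0.0)
net.train()
train_out = net(x)  # tensor(1.)
net.eval()
eval_out = net(x)   # tensor(-1.)
```

Callers never touch the private helpers; they just switch the mode and invoke the model, which is exactly the contract the rest of PyTorch (and Fairseq) expects.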