I'm currently trying to extend a model that is based on FairSeq/PyTorch. During training I need to train two encoders: one with the target sample, and the original one with the source sample.
So the current forward function looks like this:
def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
Based on this idea, I want something like this:
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    # Encode the target sample as well (tgt_tokens has to be passed in explicitly)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    # Placeholder for whatever combines the two encoder outputs
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
Is there any way to do this?
Edit: These are the constraints that I have, since I need to extend FairseqEncoderDecoderModel:
@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
Edit 2:
The parameters passed to the forward function in Fairseq can be altered by implementing your own Criterion. See for example CrossEntropyCriterion, where sample['net_input'] is passed to the __call__ function of the model, which in turn invokes the forward method.
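To make that flow concrete, here is a minimal sketch that runs without Fairseq. ToyModel and toy_criterion are hypothetical stand-ins: toy_criterion only mimics the pattern described above (calling the model with **sample['net_input']). The tgt_tokens key is the name used in forward_train, and copying it from sample['target'] is an assumption about where the target tokens live in your sample.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a Fairseq model; reports what forward receives."""

    def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
                tgt_tokens=None, **kwargs):
        # Return the received arguments so we can inspect them
        return {"src_tokens": src_tokens, "tgt_tokens": tgt_tokens}


def toy_criterion(model, sample):
    # Hypothetical criterion: unpack net_input into the model call (goes via __call__)
    return model(**sample["net_input"])


sample = {
    "net_input": {
        "src_tokens": torch.tensor([[4, 5, 6]]),
        "src_lengths": torch.tensor([3]),
        "prev_output_tokens": torch.tensor([[2, 7, 8]]),
    },
    "target": torch.tensor([[7, 8, 2]]),
}
# Assumption: a custom criterion copies the target into net_input before the call
sample["net_input"]["tgt_tokens"] = sample["target"]

out = toy_criterion(ToyModel(), sample)
print(out["tgt_tokens"])  # tensor([[7, 8, 2]])
```

Whatever key you add to net_input simply arrives as a keyword argument of forward, so forward's signature must accept it (here via tgt_tokens=None).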
First of all, you should always define and use forward, not some other method that you call on the torch.nn.Module instance.
Definitely do not override eval() as shown by trsvchn, as it is the evaluation method defined by PyTorch (see the torch.nn.Module documentation). It puts all layers inside your model into evaluation mode, which changes the behaviour of specific layers: Dropout becomes a no-op, and BatchNorm uses its running statistics instead of the statistics of the current batch.
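A small sketch of what eval() actually does, using only standard torch.nn layers (the layer sizes are arbitrary): it flips the training flag on every submodule, after which Dropout passes its input through unchanged and BatchNorm becomes deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.BatchNorm1d(4))
x = torch.randn(8, 4)

model.train()
print(all(m.training for m in model.modules()))  # True

model.eval()
# eval() sets training=False on every submodule, recursively
print(any(m.training for m in model.modules()))  # False

# Dropout in eval mode returns its input unchanged
print(torch.equal(model[1](x), x))  # True

# BatchNorm now uses running statistics, so repeated passes give identical outputs
with torch.no_grad():
    same = torch.equal(model(x), model(x))
print(same)  # True
```

If you override eval() with your own logic, this recursive switch no longer happens unless you remember to call the parent implementation.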
Furthermore, you should invoke the model through its __call__ magic method, i.e. model(...) rather than model.forward(...). Why? Because __call__ is where hooks and other PyTorch-specific machinery are run.
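A quick way to see the difference, as a sketch using a standard forward hook on an nn.Linear: the hook fires when the module is called, but not when forward() is invoked directly.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(3, 2)
# Forward hooks are executed by nn.Module.__call__, not by forward() itself
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 3)
layer(x)          # goes through __call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook is skipped
handle.remove()

print(calls)  # ['hook']
```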
Secondly, do not use some external mode string variable as suggested by @Anant Mittal. That is what the training attribute in PyTorch is for: it is the standard way to tell whether the model is in eval mode or train mode, and it is toggled by model.train() and model.eval().
That being said, you are best off doing it like this:
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...  # set up self.encoder and self.decoder here

    # You could split it into two functions, but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
        tgt_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if not self.training:
            # Inference (after model.eval()): plain encoder-decoder path
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        # Training (after model.train()): also encode the target sample and combine both
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
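Since the forward above depends on Fairseq's encoder and decoder, here is a self-contained toy version of the same pattern. ToyEncoder, ToyDecoder and last_memory_len are hypothetical, and torch.cat along the time dimension is only one assumed choice for some_concatination_func. In Fairseq the encoder output is, as far as I know, a structured object that also carries a padding mask, so a real implementation would have to combine those pieces as well.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, vocab=10, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, tokens):
        return self.embed(tokens)  # (batch, len, dim)


class ToyDecoder(nn.Module):
    def __init__(self, vocab=10, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def forward(self, prev_output_tokens, encoder_out):
        # Crude stand-in for attention: add the mean of all encoder states
        context = encoder_out.mean(dim=1, keepdim=True)
        return self.out(self.embed(prev_output_tokens) + context)


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()
        self.decoder = ToyDecoder()
        self.last_memory_len = None  # only for inspecting which branch ran

    def forward(self, src_tokens, prev_output_tokens, tgt_tokens=None):
        encoder_out = self.encoder(src_tokens)
        if self.training:
            # Training only: encode the target too and concatenate along time
            autoencoder_out = self.encoder(tgt_tokens)
            encoder_out = torch.cat([encoder_out, autoencoder_out], dim=1)
        self.last_memory_len = encoder_out.size(1)
        return self.decoder(prev_output_tokens, encoder_out=encoder_out)


model = Network()
src = torch.tensor([[1, 2, 3]])
tgt = torch.tensor([[4, 5]])
prev = torch.tensor([[0, 4]])

model.train()
model(src, prev, tgt_tokens=tgt)
print(model.last_memory_len)  # 5 (3 source + 2 target states)

model.eval()
model(src, prev)
print(model.last_memory_len)  # 3 (source states only)
```

The branch is selected purely by model.train() / model.eval(), so no extra mode variable or second entry point is needed.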
You could (and arguably should) split the above into two separate methods, both called from forward, but keeping it in one is fine here since the function is short and readable. Just stick to PyTorch's way of handling things where possible rather than ad-hoc solutions. And no, there will be no problem with backpropagation: autograd records both encoder calls, so the gradients from the source path and the target path simply accumulate in the shared encoder's parameters.
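A minimal check of that claim, assuming a single bias-free nn.Linear as a stand-in for the shared encoder: after applying it to both inputs and summing the concatenated outputs, its weight gradient equals the column sums of both inputs stacked together, so both calls contribute.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 4, bias=False)  # one shared encoder, applied twice
src = torch.randn(2, 4)
tgt = torch.randn(2, 4)

# Same module on source and target, outputs concatenated, scalar loss
combined = torch.cat([encoder(src), encoder(tgt)], dim=0)
loss = combined.sum()
loss.backward()

# For y = x W^T summed over all entries, dL/dW[j, k] = sum_i x[i, k] for every row j,
# where the sum runs over the rows of BOTH inputs
expected = torch.cat([src, tgt], dim=0).sum(dim=0).expand(4, 4)
print(torch.allclose(encoder.weight.grad, expected))  # True
```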