machine-learning, pytorch, nlp, transformer-model, formal-languages

Training difficulties on Transformer seq2seq task using pytorch


I am currently working on a seq2seq task using the vanilla torch.nn.Transformer. My implementation is provided below (SimpleTransformer). I cannot get my model to output non-trivial sequences, and the loss does not seem to shrink after flattening out. Maybe a more experienced ML researcher could tell me what else to try and share their opinion on these results?

Task:

The task is the following:

Given a sentence belonging to a defined formal grammar, produce the parse tree representation of that sentence. E.g.:

Brackets refer to a certain non-terminal rule (i.e. a grammatical object / POS).

Each opening and closing bracket type will be a single token (e.g. S-opening, S-closing, NP-opening, NP-closing, ... will be separate tokens). The text will be tokenized using Byte-Pair Encoding.
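For illustration, here is a hypothetical input/output pair under a toy grammar. The grammar, the bracket token spellings, and the sentence are made up for this post, not taken from my actual data set:

    # Hypothetical example of the seq2seq format (toy grammar, assumed bracket token names).
    # Source side: the raw sentence, tokenized with BPE.
    src = "the dog chases the cat"

    # Target side: the linearized parse tree; every bracket type is a single
    # vocabulary entry, e.g. "(S" and ")S" each map to one token id.
    tgt = "(S (NP the dog )NP (VP chases (NP the cat )NP )VP )S"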

Training params

I trained my model using cross-entropy loss and AdamW (betas 0.9, 0.98), with learning rates of magnitude 1e-4 and 1e-5, a maximum sequence length of 512 tokens, a batch size of 4 (GPU memory wasn't big enough for more), and a dataset of ~70,000 example sentences over 3-4 epochs.
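A minimal sketch of that setup, using the SimpleTransformer class shown below (the pad id, vocabulary size, and the exact shift/padding handling are assumptions for illustration, not copied from my real script):

    import torch
    import torch.nn as nn

    PAD_ID = 0           # assumption: id of the padding token
    VOCAB_SIZE = 30_000  # assumption: BPE vocabulary size

    model = SimpleTransformer(vocab_size=VOCAB_SIZE, ntokens=512)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.98))  # also tried 1e-5
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)  # don't penalize padded positions

    def training_step(in_ids, in_masks, l_ids, l_masks):
        # Teacher forcing: feed the target shifted right, predict the next token.
        decoder_in, decoder_in_mask = l_ids[:, :-1], l_masks[:, :-1]
        labels = l_ids[:, 1:]

        logits = model(in_ids, decoder_in, in_masks, decoder_in_mask)  # (batch, tgt_len - 1, vocab)
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1).long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()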

Results

My model:

import math

import torch
import torch.nn as nn
from torch.nn import Transformer


class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size: int, ntokens=512, d_model=512, num_layers=6, bidirectional=False, device="cpu"):
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(vocab_size, self.d_model)
        self.tgt_embed = nn.Embedding(vocab_size, self.d_model)
        self.positional_encoder = PositionalEncoding(d_model=self.d_model, max_len=ntokens)
        self.model = Transformer(d_model=self.d_model, batch_first=True, num_encoder_layers=num_layers,
                                 num_decoder_layers=num_layers)
        self.bidirectional = bidirectional
        self.generator = Generator(hidden_size=self.d_model, vocab_size=vocab_size)  # just a fc layer
        self.device = device

    def forward(self, in_ids, l_ids, in_masks, l_masks):
        in_ids = self.src_embed(in_ids.long()) * math.sqrt(self.d_model)  # scale by sqrt of d_model
        in_ids = self.positional_encoder(in_ids)

        l_ids = self.tgt_embed(l_ids.long()) * math.sqrt(self.d_model)
        l_ids = self.positional_encoder(l_ids)

        # Create masks
        src_seq_len = in_ids.size(1)
        tgt_seq_len = l_ids.size(1)
        # No masking on the encoder side (all-False boolean mask).
        src_mask = torch.zeros(src_seq_len, src_seq_len, device=self.device).type(torch.bool)
        if not self.bidirectional:
            # Causal mask so the decoder cannot attend to future target positions.
            tgt_mask = torch.triu(torch.full((tgt_seq_len, tgt_seq_len), float('-inf'), device=self.device), diagonal=1)
        else:
            tgt_mask = torch.zeros(tgt_seq_len, tgt_seq_len, device=self.device).type(torch.bool)
        in_masks = in_masks == 0.0  # in_masks will mask special pad tokens
        l_masks = l_masks == 0.0  # l_masks will mask special pad tokens

        out = self.model(src=in_ids, tgt=l_ids,
                         src_mask=src_mask, tgt_mask=tgt_mask,
                         src_key_padding_mask=in_masks,
                         tgt_key_padding_mask=l_masks)
        return self.generator(out)
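And a quick shape check of how I call it (batch-first everywhere; the concrete sizes and dummy data are just for illustration, and PositionalEncoding / Generator come from my full script):

    import torch

    vocab_size = 30_000  # assumption
    model = SimpleTransformer(vocab_size=vocab_size)

    in_ids = torch.randint(0, vocab_size, (4, 20))   # (batch, src_len)
    l_ids = torch.randint(0, vocab_size, (4, 30))    # (batch, tgt_len)
    in_masks = torch.ones(4, 20)                     # 1.0 = real token, 0.0 = padding
    l_masks = torch.ones(4, 30)

    logits = model(in_ids, l_ids, in_masks, l_masks)
    print(logits.shape)  # expected: (4, 30, vocab_size)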

I tried with:

I saw that the model outputs look like this:

(s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s (s ...

No matter the input... I just don't know whether the task is too difficult to handle or whether I made a mistake somewhere in the process. I can't get my model to output non-trivial patterns. I'm quite new to machine learning, so it could certainly be a stupid mistake somewhere. Maybe I stopped training too early, or some hyper-parameters are wrong?

Comparative Attempts

Additionally, I used the facebook/fairseq toolkit to train a model on the same task. Its performance was better, with the loss decreasing significantly:

I also fine-tuned the pre-trained bart-base model from the huggingface library with the identical training script and training data. Its performance was far better than both of the previous models.
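Roughly what that looked like (a minimal sketch, not my exact script; the example sentence/parse pair and the max-length handling are assumptions):

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    # Illustrative single pair; in my runs the same ~70,000 sentence/parse-tree pairs were used.
    sentence = "the dog chases the cat"                  # hypothetical
    parse = "(S (NP the dog) (VP chases (NP the cat)))"  # hypothetical

    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=512)
    labels = tokenizer(parse, return_tensors="pt", truncation=True, max_length=512).input_ids

    # When labels are passed, the model computes the cross-entropy loss internally.
    loss = model(input_ids=inputs.input_ids,
                 attention_mask=inputs.attention_mask,
                 labels=labels).loss
    loss.backward()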


Solution

  • OK, I think I solved this. For people experiencing the same problems:

    What did not solve my problem, but might help you:

    What did in the end solve my problem:

    Employing the learning rate scheduler used in the paper: it linearly increases the learning rate until step 4000 and then decays it proportionally to the inverse square root of the step number, i.e. lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)), which can be found here.

    TLDR

    Use this wrapper around your optimizer (it will change the learning rate according to the "Attention Is All You Need" paper):

    class NoamOptim(object):
        """ Optimizer wrapper for learning rate scheduling. """

        def __init__(self, optimizer, d_model, factor=2, n_warmup_steps=4000):
            self.optimizer = optimizer
            self.d_model = d_model
            self.factor = factor
            self.n_warmup_steps = n_warmup_steps
            self.n_steps = 0

        def zero_grad(self):
            self.optimizer.zero_grad()

        def step(self):
            # Update the learning rate of every parameter group, then take an optimizer step.
            self.n_steps += 1
            lr = self.get_lr()
            for p in self.optimizer.param_groups:
                p['lr'] = lr
            self.optimizer.step()

        def get_lr(self):
            # lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
            return self.factor * (
                self.d_model ** (-0.5)
                * min(self.n_steps ** (-0.5), self.n_steps * self.n_warmup_steps ** (-1.5))
            )
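
    Usage, roughly (the wrapper replaces direct calls to the optimizer; the stand-in model below is illustrative, and factor / warm-up steps follow the defaults in the wrapper above):

    import torch
    import torch.nn as nn

    # Any module's parameters work here; a stand-in layer keeps the example self-contained.
    model = nn.Linear(512, 512)
    base_optimizer = torch.optim.AdamW(model.parameters(), lr=0.0, betas=(0.9, 0.98))
    # The wrapper overwrites the learning rate on every step, so the initial lr does not matter.
    optimizer = NoamOptim(base_optimizer, d_model=512, factor=2, n_warmup_steps=4000)

    # In the training loop, use the wrapper exactly like a normal optimizer:
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()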