
How to get LSTM to handle long input samples, without reducing the size of the network?


Background:

I'm using an LSTM (built in tensorflow.js) to generate text based on an input text file. I split the text file into samples, and each sample originally had a length of 10 characters. I then scaled that up to 20, and everything was fine. Finally, I scaled it up to 100, and the LSTM's gradients exploded and the loss became NaN.
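The question doesn't show how the samples are built (that lives in the linked gist), so the following is only a hypothetical sketch of a common approach, assuming overlapping sliding windows with a next-character target. It illustrates why sample length matters: each sample fed to the LSTM is a seqLength × chars.length one-hot matrix, so going from 10 to 100 characters means backpropagating through ten times as many timesteps. The names `makeSamples` and `oneHotInput` are illustrative, not from the original code.

```javascript
// Hypothetical sketch: split text into overlapping fixed-length windows,
// each paired with the character that follows it (next-character prediction).
function makeSamples(text, seqLength, step = 1) {
  const samples = [];
  for (let i = 0; i + seqLength < text.length; i += step) {
    samples.push({
      input: text.slice(i, i + seqLength), // seqLength characters of context
      target: text[i + seqLength],         // the character to predict
    });
  }
  return samples;
}

// One-hot encode a sample's input as a seqLength x chars.length matrix,
// the per-sample layout an LSTM with inputShape [null, chars.length] expects.
// Assumes every character of `input` appears in `chars`.
function oneHotInput(input, chars) {
  return Array.from(input, (c) => {
    const row = new Array(chars.length).fill(0);
    row[chars.indexOf(c)] = 1;
    return row;
  });
}

console.log(makeSamples("abcdef", 3).length); // 3
```

With a sliding window of step 1, the number of samples stays roughly the same as the length grows, but each sample gets longer, which is exactly what makes the unrolled gradient path longer.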

I tried gradient clipping, gradient normalization, weight regularization, reducing the batch size, and changing my model's architecture. Nothing worked.

The only thing that did help was shrinking my LSTM (from three 512-unit layers to three 64-unit layers), but I want my model to keep its computational power.

The structure of my model is the following:

// chars.length is the number of unique characters in the training set.
const model = tf.sequential({
    layers: [
        tf.layers.lstm({ inputShape: [null, chars.length], units: 512, activation: "relu", returnSequences: true }),
        tf.layers.lstm({ units: 512, activation: "relu", returnSequences: true }),
        tf.layers.lstm({ units: 512, activation: "relu", returnSequences: false }),
        tf.layers.dense({ units: chars.length, activation: "softmax" }),
    ]
});
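To put the "computational power" trade-off into numbers, here is a small sketch that counts trainable parameters for the architecture above at 512 and at 64 units. It assumes the standard LSTM weight layout used by tf.js and Keras (kernel, recurrent kernel, bias, each covering four gates) and a hypothetical vocabulary of 50 characters; the real value is `chars.length`.

```javascript
// Trainable parameters in one LSTM layer (tf.js / Keras layout):
// kernel [inputDim, 4*units] + recurrentKernel [units, 4*units] + bias [4*units].
function lstmParams(inputDim, units) {
  return 4 * (units * (inputDim + units) + units);
}

// Three stacked LSTM layers of equal width, then a dense softmax over the
// vocabulary, mirroring the model in the question.
function modelParams(vocabSize, units) {
  const lstm1 = lstmParams(vocabSize, units);
  const lstm2 = lstmParams(units, units);
  const lstm3 = lstmParams(units, units);
  const dense = units * vocabSize + vocabSize; // weights + biases
  return lstm1 + lstm2 + lstm3 + dense;
}

const vocabSize = 50; // hypothetical: the real value is chars.length
console.log(modelParams(vocabSize, 512)); // 5377074
console.log(modelParams(vocabSize, 64));  // 98738
```

Under these assumptions the 512-unit model has roughly 54 times as many parameters as the 64-unit one, which is why shrinking the network is such a large concession.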

I then compile the model like this:

model.compile({
    optimizer: "adam",
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"],
    clipValue: 0.5,
    clipNorm: 1,
    learningRate: 0.001
});
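Note that in tf.js, `model.compile()` only reads `optimizer`, `loss`, and `metrics`, so the `clipValue`, `clipNorm`, and `learningRate` keys above are silently ignored, and `tf.train.adam(learningRate, beta1, beta2, epsilon)` has no clipping option either. Clipping therefore has to be applied by hand in a custom training step, for example by getting gradients with `optimizer.computeGradients()`, clipping them, and then calling `optimizer.applyGradients()`. The sketch below shows, on plain arrays rather than tensors, what clipping by value and clipping by norm do to a single gradient; the helper names are hypothetical.

```javascript
// Hypothetical helpers on plain arrays, mirroring the math of gradient clipping.
// (With tensors, tf.clipByValue covers the first case.)

// Clip every component into [-clipValue, clipValue].
function clipByValue(grad, clipValue) {
  return grad.map((g) => Math.min(clipValue, Math.max(-clipValue, g)));
}

// Rescale the whole gradient so its L2 norm is at most maxNorm.
// Unlike clipByValue, this preserves the gradient's direction.
function clipByNorm(grad, maxNorm) {
  const norm = Math.sqrt(grad.reduce((sum, g) => sum + g * g, 0));
  if (norm <= maxNorm) return grad.slice();
  return grad.map((g) => (g * maxNorm) / norm);
}

console.log(clipByValue([3, -0.2, -4], 0.5)); // [ 0.5, -0.2, -0.5 ]
console.log(clipByNorm([3, 4], 1));           // [ 0.6, 0.8 ]
// Neither helper can repair a gradient that already contains NaN:
// NaN propagates through both.
```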

I've checked my training data, and it is all valid.

Why would my LSTM's gradients still explode if they are being clipped? Is something else causing the NaN loss? Any ideas on how to fix it?

(The conventional methods of weight regularization, batch size reduction, and gradient clipping didn't work.)

Full Code: https://gist.github.com/N8python/22c42550ae1cf50236a4c63720cc3ee8


Solution

  • Try training on shorter sequences first, then continue training the same model, with the same weights, on progressively longer sequences. A model that starts from randomly initialized weights on long sequences tends to be unstable, so it helps to begin from the reasonable weights obtained by training on shorter sequences before moving to longer ones.
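A minimal sketch of that schedule, assuming the model from the question. Because its first LSTM layer was declared with `inputShape: [null, chars.length]`, the time dimension is variable, so the same model object (and its weights) can be fit on length-10 data and then on longer data without being rebuilt. The stage lengths, epoch count, batch size, and the `buildTensors` callback are hypothetical placeholders, not values from the original answer.

```javascript
// Hypothetical curriculum-training driver: the SAME model (same weights) is fit
// on progressively longer sequences, shortest first.
//
// `buildTensors(text, seqLength)` is a hypothetical callback that must return
// { xs, ys } tensors shaped [numSamples, seqLength, vocab] and [numSamples, vocab].
async function trainWithCurriculum(model, text, buildTensors, {
  lengths = [10, 20, 50, 100], // placeholder schedule: short sequences first
  epochsPerStage = 5,
  batchSize = 32,
} = {}) {
  for (const seqLength of lengths) {
    const { xs, ys } = buildTensors(text, seqLength);
    try {
      // tf.js LayersModel.fit(x, y, args) returns a Promise<History>.
      await model.fit(xs, ys, { epochs: epochsPerStage, batchSize });
    } finally {
      // Free this stage's tensors before building the next, longer stage.
      xs.dispose();
      ys.dispose();
    }
  }
  return model; // same object, now warmed up on every stage
}
```

In tf.js, `buildTensors` could one-hot encode the windows and wrap them with `tf.tensor3d` and `tf.tensor2d`. A reasonable rule of thumb is to move to the next length only once the loss on the current one has settled.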