Tags: recurrent-neural-network, jax, gru, flax

How can I initialize the hidden state (carry) of a Flax linen GRUCell as a learnable parameter (e.g. using model.init)?


I create a GRU model in JAX using Flax and initialize the model parameters using model.init as follows:

import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers

class RNN(nn.Module):
    n_RNN_units: int

    @nn.compact
    def __call__(self, carry, inputs):
        
        carry, outputs = nn.GRUCell()(carry, inputs)
        
        return carry, outputs
    
    def init_state(self):
        
        return nn.GRUCell.initialize_carry((), (), self.n_RNN_units, init_fn = initializers.zeros)

# instantiate an RNN (GRU) model
n_RNN_units = 200
model = RNN(n_RNN_units = n_RNN_units)

# initialize the parameters of the model (weights and biases)
data_dim = 20
params = model.init(carry = np.empty((n_RNN_units,)), inputs = np.empty((data_dim,)), rngs = {'params': random.PRNGKey(1)})

Unfortunately for me, the FrozenDict params created by model.init only contains the weights and biases of the GRU, not the initial hidden state (carry). Is there a way to tell model.init 1) that I also want to learn the initial hidden state and 2) which initializer function to use for it?
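For reference, this is how I inspect the parameter tree (a minimal sketch; the exact key names inside the tree depend on the Flax version):

import jax
# Print the shape of every parameter model.init produced.
# Only the GRUCell kernels and biases appear; there is no entry
# for an initial carry anywhere in the tree.
print(jax.tree_util.tree_map(lambda p: p.shape, params))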

Alternatively, if there is a better way to do this that does not involve using model.init, feel free to suggest that.

Thanks in advance


Solution

  • You can use self.param to register a tensor as a parameter:

    @nn.compact
    def __call__(self, inputs, carry=None):
        if carry is None:
            # Learnable initial carry, registered as a model parameter.
            # Note: this uses jax.numpy imported as jnp.
            carry = self.param('carry_init', lambda rng, shape: jnp.zeros(shape), (self.n_RNN_units,))
        carry, outputs = nn.GRUCell()(carry, inputs)
        return carry, outputs
    

    Now carry_init is part of the model parameters after model.init(rng, inputs) (carry defaults to None, so model.init(rng, inputs, None) is equivalent).
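    For example, after redefining RNN with the __call__ above, a quick check (a minimal sketch reusing the shapes from the question) confirms the new parameter:

    import jax.numpy as jnp
    from jax import random

    n_RNN_units, data_dim = 200, 20
    model = RNN(n_RNN_units=n_RNN_units)
    # carry is not passed, so the learnable carry_init gets created.
    variables = model.init(random.PRNGKey(1), jnp.empty((data_dim,)))
    print('carry_init' in variables['params'])      # True
    print(variables['params']['carry_init'].shape)  # (200,)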

    What happens now is that model.apply takes the parameters params with carry_init in them, so gradients with respect to it are computed as usual with grad.

    More precisely, when predicting over a sequence you start with carry, outputs = model.apply(params, inputs), which uses carry_init from params. For the following steps you call carry, outputs = model.apply(params, inputs, carry), which uses the carry you pass in. Because carry_init sits at the start of the computation graph of every outputs and carry, gradients propagate back to it. Be aware, however, that this gradient can vanish badly for long sequences, so you may want to compute the loss from all steps of the sequence (especially the first ones) or adapt its learning rate to the sequence length, as in the sketch below.
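    As a concrete illustration, here is a minimal training-step sketch (the loss, sequence and targets are placeholders, not part of the original answer) that unrolls the model over a sequence and differentiates through carry_init:

    import jax
    import jax.numpy as jnp

    seq_len = 5
    sequence = jnp.ones((seq_len, data_dim))      # placeholder inputs
    targets = jnp.zeros((seq_len, n_RNN_units))   # placeholder targets

    def loss_fn(variables):
        carry, total = None, 0.0
        for t in range(seq_len):
            # First call: carry is None, so carry_init is used.
            # Later calls: the running carry is threaded through.
            carry, out = model.apply(variables, sequence[t], carry)
            total += jnp.mean((out - targets[t]) ** 2)  # placeholder loss
        return total / seq_len

    grads = jax.grad(loss_fn)(variables)
    # grads['params']['carry_init'] is the gradient of the initial state.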

    Details of linen.Module.param are in the Flax documentation under Managing Parameters and State.