I create a GRU model in JAX using Flax and initialize the model parameters with model.init as follows:
import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers

class RNN(nn.Module):
    n_RNN_units: int

    @nn.compact
    def __call__(self, carry, inputs):
        carry, outputs = nn.GRUCell()(carry, inputs)
        return carry, outputs

    def init_state(self):
        return nn.GRUCell.initialize_carry((), (), self.n_RNN_units, init_fn=initializers.zeros)

# instantiate an RNN (GRU) model
n_RNN_units = 200
model = RNN(n_RNN_units=n_RNN_units)

# initialize the parameters of the model (weights and biases)
data_dim = 20
params = model.init(carry=np.empty((n_RNN_units,)), inputs=np.empty((data_dim,)), rngs={'params': random.PRNGKey(1)})
Unfortunately for me, the FrozenDict params created by model.init only contains the weights and biases of the GRU, not the initial hidden state (carry). Is there a way to tell model.init (1) that I also want to learn the initial hidden state, and (2) which initializer function to use for it?
Alternatively, if there is a better way to do this that does not involve using model.init, feel free to suggest that.
Thanks in advance
You can use self.param to register a tensor as a parameter:
import jax.numpy as jnp

class RNN(nn.Module):
    n_RNN_units: int

    @nn.compact
    def __call__(self, inputs, carry=None):
        if carry is None:
            # Learnable initial carry, registered as a model parameter
            carry = self.param('carry_init', lambda rng, shape: jnp.zeros(shape),
                               (self.n_RNN_units,))
        carry, outputs = nn.GRUCell()(carry, inputs)
        return carry, outputs
Now carry_init is in the model parameters after model.init(rng, inputs, None).
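As a quick check, here is a minimal sketch assuming the modified RNN class above and import jax.numpy as jnp; the dummy shapes and variable names are only illustrative:

rng = random.PRNGKey(1)
model = RNN(n_RNN_units=200)
dummy_inputs = jnp.zeros((20,))               # one time step of dimension data_dim
params = model.init(rng, dummy_inputs, None)  # carry=None triggers the learnable carry
print(params['params'].keys())                # includes 'carry_init' alongside the GRUCell weights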
What happens now is that model.apply takes the params dict containing carry_init, so gradients with respect to it are computed as usual with grad.
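For example, a hedged single-step sketch (loss_fn, targets, and the squared-error loss are illustrative, not part of the original code):

import jax

def loss_fn(params, inputs, targets):
    # First call without a carry: the learnable carry_init is used
    _, outputs = model.apply(params, inputs)
    return jnp.mean((outputs - targets) ** 2)

grads = jax.grad(loss_fn)(params, dummy_inputs, jnp.zeros((200,)))
# grads['params']['carry_init'] has the same shape as the initial carry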
More precisely, when making a prediction over a sequence, start your calls with carry, outputs = model.apply(params, inputs); this uses carry_init from params. For the following calls, use carry, outputs = model.apply(params, inputs, carry); this uses the explicit carry, and carry_init remains on the computation graph of outputs and carry as the initial carry, so gradients can propagate to it. However, be aware of potentially severe gradient vanishing for it with long sequences, so you may want to compute the loss on all values of the sequence (especially the first) or adapt a dedicated learning rate based on sequence length.
Details of linen.Module.param are in the Flax documentation under Managing Parameters and State.