I am working on a project with RNNs using JAX and Flax, and I have noticed some behavior that I do not really understand.
My code is basically an optimization loop in which the user provides the initial parameters for the system they want to optimize. The system is divided into several time steps. The initial input is fed into the first time step of the system, which produces an output. That output is fed into an RNN, which returns the parameters for the following time step, and so on. The whole thing is then optimized with Adam (specifically, using optax).
The user provides their initial parameters as a dict, and a function called prepare_parameters_from_dict converts this dict into a list of lists (or, alternatively, a list of jnp arrays).
My question/observation: when I make this function return a list of jnp.arrays instead of a list of lists, the property I am optimizing gets roughly an order of magnitude worse.
For example, the list-of-lists version reaches 0.9997, while the list-of-jnp.arrays version reaches 0.998 (the closer to one, the better, so the distance from 1 goes from 0.0003 to 0.002).
Note: the RNN outputs a list of jnp.arrays (it uses flax linen), and everything else in the code remains the same.
Here are the two versions of the function:
Version returning a list of lists:

```python
import jax


def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    res = []
    shapes = []
    for value in params_dict.values():
        # Leaves are kept as-is (plain Python floats for the example input below).
        flat_params = jax.tree_util.tree_leaves(value)
        res.append(flat_params)
        shapes.append(len(flat_params))
    return res, shapes
```
Version returning a list of jnp.arrays:

```python
import jax
import jax.numpy as jnp


def prepare_parameters_from_dict(params_dict):
    """
    Convert a nested dictionary of parameters to a flat list and record shapes.

    Args:
        params_dict: Nested dictionary of parameters.

    Returns:
        tuple: Flattened parameters list and list of shapes.
    """
    res = []
    shapes = []
    for value in params_dict.values():
        flat_params = jax.tree_util.tree_leaves(value)
        # Leaves are packed into a single jnp array (strongly typed).
        res.append(jnp.array(flat_params))
        shapes.append(jnp.array(flat_params).shape[0])
    return res, shapes
```
And this is an example of the user's initial parameters:

```python
initial_params = {
    "param1": {
        "gamma": 0.1,
        "delta": -3 * jnp.pi / 2,
    }
}
```
The rest of the code remains exactly the same for both.
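To see concretely what each version hands to the rest of the code, here is a minimal sketch that runs both on the example input. It assumes 64-bit mode is enabled (the float64 arrays printed below suggest it is); the names `prepare_as_lists` and `prepare_as_arrays` are hypothetical, used only so both versions can live in one script.

```python
import jax
import jax.numpy as jnp

# Assumption: x64 mode is on, as the float64 dtypes in the question suggest.
jax.config.update("jax_enable_x64", True)

initial_params = {
    "param1": {
        "gamma": 0.1,
        "delta": -3 * jnp.pi / 2,  # jnp.pi is a Python float, so this stays a Python float
    }
}


# Hypothetical names for the two versions of prepare_parameters_from_dict.
def prepare_as_lists(params_dict):
    res, shapes = [], []
    for value in params_dict.values():
        flat_params = jax.tree_util.tree_leaves(value)
        res.append(flat_params)
        shapes.append(len(flat_params))
    return res, shapes


def prepare_as_arrays(params_dict):
    res, shapes = [], []
    for value in params_dict.values():
        flat_params = jnp.array(jax.tree_util.tree_leaves(value))
        res.append(flat_params)
        shapes.append(flat_params.shape[0])
    return res, shapes


list_res, list_shapes = prepare_as_lists(initial_params)
array_res, array_shapes = prepare_as_arrays(initial_params)

# List version: plain Python floats, which JAX treats as weakly typed.
print([type(v).__name__ for v in list_res[0]])  # ['float', 'float']
# Array version: one strongly typed float64 array.
print(array_res[0].dtype)  # float64
# The recorded shapes are identical, so only the types differ.
print(list_shapes, array_shapes)  # [2] [2]
```

So the two versions carry the same values and shapes; the only difference is whether the first time step's parameters arrive as Python scalars or as a float64 array.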
After optimization with, for example, five time steps, the final optimized parameters for each time step look like this:
Using a list of jnp.arrays:

```
[[Array([ 0.1       , -4.71238898], dtype=float64)],
 [Array([-0.97106537, -0.03807388], dtype=float64)],
 [Array([-1.17050792, -0.01463591], dtype=float64)],
 [Array([-0.77229875, -0.0124556 ], dtype=float64)],
 [Array([-1.56113376, -0.01103598], dtype=float64)]]
```
Using a list of lists:

```
[[[0.1, -4.71238898]],
 [Array([-0.97106537, -0.03807388], dtype=float64)],
 [Array([-1.17050792, -0.01463591], dtype=float64)],
 [Array([-0.77229875, -0.0124556 ], dtype=float64)],
 [Array([-1.56113376, -0.01103598], dtype=float64)]]
```
Could such a difference in behavior be due to how JAX handles grad, jit, and other transformations with lists compared to jnp.arrays, or am I missing something?
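The main operative difference between these two cases is that Python floats are treated as weakly typed, meaning that the list version of your code could result in operations being performed at a lower precision. When a weakly typed value is combined with a typed array, the result takes the array's dtype, so a Python float combined with a float32 value yields float32 even when 64-bit mode is enabled. For example: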
```
In [1]: import jax

In [2]: import jax.numpy as jnp

In [3]: jax.config.update('jax_enable_x64', True)

In [4]: list_values = [0.1, -4.71238898]

In [5]: array_values = jax.numpy.array(list_values)

In [6]: x = jax.numpy.float32(1.0)

In [7]: x + list_values[1]
Out[7]: Array(-3.712389, dtype=float32)

In [8]: x + array_values[1]
Out[8]: Array(-3.71238898, dtype=float64)
```
Notice that the array version leads to higher-precision computations in this case. If I had to guess, the difference between your two runs comes down to the precision implied by strong versus weak types.
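The same promotion happens inside jit-compiled code, because jit keeps the weak type of Python scalar arguments. Below is a minimal sketch, assuming 64-bit mode is enabled; `step` is a hypothetical stand-in for one time step of your system, and `x32` stands for any float32 value in that computation. The last lines show one possible fix (an assumption, not the only option): give each leaf an explicit dtype so the list version is strongly typed as well.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is enabled, as in the question.
jax.config.update("jax_enable_x64", True)


# Hypothetical stand-in for one time step: combines a parameter with a float32 value.
def step(param, x32):
    return jnp.sin(x32 + param) ** 2


f = jax.jit(step)
x32 = jnp.float32(1.0)

py_param = -4.71238898                          # list-of-lists leaf: weakly typed Python float
arr_param = jnp.array([0.1, -4.71238898])[1]    # array leaf: strongly typed float64

print(f(py_param, x32).dtype)   # float32: the weak Python float defers to float32
print(f(arr_param, x32).dtype)  # float64: strong float64 wins the promotion

# Possible fix: an explicit dtype makes the scalar strongly typed.
strong_param = jnp.asarray(py_param, dtype=jnp.float64)
print(f(strong_param, x32).dtype)  # float64
```

Passing an explicit `dtype` is what matters here: without it, the Python scalar keeps its weak type, and the float32 operand it is combined with decides the precision of the result.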