I'm trying to run a model written in JAX, https://github.com/lindermanlab/S5. However, I ran into the following error:
Traceback (most recent call last):
File "/Path/run_train.py", line 101, in <module>
train(parser.parse_args())
File "/Path/train.py", line 144, in train
state = create_train_state(model_cls,
File "/Path/train_helpers.py", line 135, in create_train_state
params = variables["params"].unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'
I tried to replicate this error with:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = flax.traverse_util.unfreeze(params)
And the error reads:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'
I'm using:
flax 0.7.4
jax 0.4.13
jaxlib 0.4.13+cuda12.cudnn89
I think this issue is related to the version of flax, but does anyone know what exactly is going on? Any help is appreciated. Let me know if you need any further information.
unfreeze is a method of Flax's FrozenDict class (see FrozenDict.unfreeze). It appears that you have passed a Python dict where a FrozenDict is expected.
To fix this, you should ensure that variables['params'] is a FrozenDict, not a dict.
Regarding the error in your attempted replication: flax.traverse_util does not define an unfreeze function (the module-level unfreeze function lives in flax.core), but this seems unrelated to the original problem.