Tags: python, attributeerror, jax, flax

AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'


I'm trying to run a model written in JAX (https://github.com/lindermanlab/S5), but I ran into the following error:

Traceback (most recent call last):
  File "/Path/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "/Path/train.py", line 144, in train
    state = create_train_state(model_cls,
  File "/Path/train_helpers.py", line 135, in create_train_state
    params = variables["params"].unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'

I tried to reproduce the error with:

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = flax.traverse_util.unfreeze(params)

And the error reads:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'

I'm using:

flax 0.7.4
jax 0.4.13
jaxlib 0.4.13+cuda12.cudnn89

I suspect this is related to the Flax version, but does anyone know exactly what is going on? Any help is appreciated, and let me know if you need more information.


Solution

  • unfreeze is a method of Flax's FrozenDict class (see FrozenDict.unfreeze), not of Python's built-in dict. The traceback shows that variables["params"] is a plain Python dict, so the call fails at the point where the code expects a FrozenDict.

    To fix this, ensure that variables['params'] is a FrozenDict rather than a dict, for example by converting it with flax.core.freeze before calling .unfreeze().

    As for the error in your reproduction attempt: flax.traverse_util does not define an unfreeze function (the function form is flax.core.unfreeze), so that AttributeError is a separate issue from the original problem.