My code is as follows:
!pip install flax
import flax  # the failing attribute lookup below is on this module
init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
print(f'Model parameters: {n_params(init_params):,}')
optim = flax.optim.Adam(lr=1e-4).create(init_params)
However, it raises the following error:
AttributeError: module 'flax' has no attribute 'optim'
I have seen documentation for the optim attribute of the flax module, so why is it missing, and how do I fix this?
flax.optim was removed in Flax 0.6.0, which is why the attribute lookup fails. For now, you can work around the issue by downgrading Flax from 0.6.0 to 0.5.1, which still ships flax.optim:
pip install flax==0.5.1