python tensorflow anaconda jax tensorflow-probability

Unable to import tensorflow_probability.substrates.jax


I am trying to import tensorflow_probability.substrates.jax (specifically to use the distributions) and getting the error shown below (it looks like a self-import). I have installed tensorflow (2.8.2), tensorflow-probability (0.14.0) and jax (0.3.25).

Trying

import tensorflow_probability.substrates.jax as tfp

I get

ImportError: cannot import name 'bijectors' from partially initialized module
'tensorflow_probability.substrates.jax' (most likely due to a circular import)
(/path-to-anaconda3-env/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/__init__.py)

I have tried a few different versions of tensorflow-probability with the same results.


Solution

  • This sounds like it's due to a version incompatibility. tensorflow_probability v0.14 was released in Sept 2021, at which point JAX's most recent release was version 0.2.20. JAX has had 36 releases since then, so it's not surprising that some incompatibilities have arisen.

    I tried in Google Colab and found that the following combination works:

    import tensorflow
    import tensorflow_probability
    import jax
    print(f"{jax.__version__=}")
    print(f"{tensorflow.__version__=}")
    print(f"{tensorflow_probability.__version__=}")
    
    import tensorflow_probability.substrates.jax as tfp
    print("loaded!")

    Output:

    jax.__version__='0.3.25'
    tensorflow.__version__='2.9.2'
    tensorflow_probability.__version__='0.17.0'
    loaded!
    

    Another thing that can cause similar issues is if you are working in a notebook environment and installing new versions of packages that you've already imported. If you're working in notebooks, be sure to restart your Python runtime after you install or update a package.
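
Both causes above (mismatched versions, and a package updated after it was already imported) can be checked without triggering the failing import. The following is a minimal sketch, not an official diagnostic: the helper names `version_report` and `stale_imports` are hypothetical, and it assumes the pip distribution names are `tensorflow`, `tensorflow-probability` and `jax`, and that each module's `__version__` matches its pip metadata string (which may not hold for nightly or dev builds).

```python
import sys
from importlib import metadata

# Assumed mapping of pip distribution names to import names.
PACKAGES = {
    "tensorflow": "tensorflow",
    "tensorflow-probability": "tensorflow_probability",
    "jax": "jax",
}


def version_report(packages=PACKAGES):
    """Return {dist_name: (installed_version, loaded_version)}.

    installed_version is None if the distribution is not installed;
    loaded_version is None if the module has not been imported yet
    (or does not expose __version__).
    """
    report = {}
    for dist_name, module_name in packages.items():
        try:
            installed = metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            installed = None
        # Look the module up in sys.modules so nothing new gets imported.
        module = sys.modules.get(module_name)
        loaded = getattr(module, "__version__", None) if module else None
        report[dist_name] = (installed, loaded)
    return report


def stale_imports(report):
    """Names whose in-memory version differs from the one on disk."""
    return [
        name
        for name, (installed, loaded) in report.items()
        if installed and loaded and installed != loaded
    ]


if __name__ == "__main__":
    report = version_report()
    for name, (installed, loaded) in report.items():
        print(f"{name}: installed={installed} loaded={loaded}")
    for name in stale_imports(report):
        print(f"{name} was updated after it was imported; restart the runtime")
```

If a package shows up as stale, restarting the runtime and re-running the imports is the first thing to try before changing versions again.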