
Unable to Install Specific JAX jaxlib GPU version


I'm trying to install a particular version of jaxlib that matches my CUDA and cuDNN versions. Following the README, I'm running:

pip install --upgrade jax jaxlib==0.1.52+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

This returns the following error:

ERROR: Requested jaxlib==0.1.52+cuda101 from https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.52%2Bcuda101-cp37-none-manylinux2010_x86_64.whl has different version in metadata: '0.1.52'

Does anyone know what causes this or how to get around the error?


Solution

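  • This error appears to come from a new check in pip 20.3 and later, likely related to the new dependency resolver that became the default in that release. pip now verifies that the version in a wheel's filename matches the version recorded in its metadata. Here the filename says '0.1.52+cuda101' but the metadata says '0.1.52', so pip rejects the wheel. I can reproduce this error with pip 20.3.3, but the package installs correctly with pip 20.2.4.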

    The easiest workaround is probably to downgrade pip first:

    pip install pip==20.2.4
    

    and then proceed with your jaxlib install.
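To illustrate what the error is comparing, here is a small standard-library Python sketch. It is not pip's code. It only pulls the version field out of the wheel filename in the URL from the error message and compares it with the metadata version quoted there. The helper name wheel_version_from_url is hypothetical. It relies on the wheel filename layout {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl, in which the version is always the second "-"-separated field.

```python
from urllib.parse import unquote, urlparse
import posixpath


def wheel_version_from_url(url: str) -> str:
    """Hypothetical helper: return the version field of a wheel filename in a URL.

    Wheel filenames look like {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl,
    and "-" inside the name or version is escaped to "_", so the version is always
    the second "-"-separated field.
    """
    # Last path component, with percent-escapes such as %2B decoded back to "+".
    filename = unquote(posixpath.basename(urlparse(url).path))
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    return filename[: -len(".whl")].split("-")[1]


# URL and metadata version copied from the error message in the question.
URL = ("https://storage.googleapis.com/jax-releases/cuda101/"
       "jaxlib-0.1.52%2Bcuda101-cp37-none-manylinux2010_x86_64.whl")
METADATA_VERSION = "0.1.52"

filename_version = wheel_version_from_url(URL)
print(filename_version)                      # 0.1.52+cuda101
print(filename_version == METADATA_VERSION)  # False: metadata lacks the "+cuda101" label
```

Comparing them as PEP 440 versions instead of raw strings gives the same answer, because the local version label "+cuda101" is part of the version. A wheel whose filename and metadata agree would pass the check.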