I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22 where every official installation method fails in a different way, always resulting in JAX falling back to the CPU.
I am looking for a definitive, foolproof set of commands to get a working GPU installation.
System Specifications:
I have meticulously created fresh conda environments for each attempt to ensure there are no conflicts.
Attempt #1: The Standard Recommended Method
This is the officially recommended command.
conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
Expected result: a jaxlib wheel with the CUDA libraries included is downloaded and installed.
Actual result: pip consistently ignores the [cuda12_pip] extra, downloads the small CPU-only jaxlib wheel (89.9 MB), and prints a warning:
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
The verification command confirms the failure:
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
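To see exactly what pip put into the environment after an attempt like this, it can help to list the JAX-related distributions and their versions. The sketch below uses only the standard library's importlib.metadata. The two CUDA package names (jax-cuda12-plugin and jax-cuda12-pjrt) are an assumption based on how recent JAX releases appear to package GPU support, and the helper name installed_jax_packages is hypothetical.

```python
from importlib import metadata

# Distribution names to look for. The two CUDA entries are the plugin
# packages used by recent JAX releases (assumption; adjust the tuple if
# your JAX version packages CUDA support differently).
JAX_DISTRIBUTIONS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_jax_packages(names=JAX_DISTRIBUTIONS):
    """Return {distribution name: installed version string, or None if absent}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Run inside the activated conda environment you want to inspect.
    for name, version in installed_jax_packages().items():
        print(f"{name:20} {version or 'not installed'}")
```

If jaxlib shows a version but neither CUDA package is listed, JAX has nothing to load a GPU backend from, which would be consistent with the CPU fallback above.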
Attempt #2: The Direct URL Method
This is the expert workaround to force installation of a specific GPU wheel.
# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy<2.0"
Expected result: the CUDA-enabled jaxlib is installed.
Actual result:
ERROR: HTTP error 404 while getting https://.../jaxlib-0.4.23...whl
This shows that relying on specific, old URLs is not a stable solution.
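The 404 URL pinned one JAX release together with one CUDA/cuDNN build, and that pairing is encoded in the wheel filename itself: the part of the version after "+" (cuda12.cudnn88) is a local version label. The hypothetical helper parse_wheel_name below splits a wheel filename into its fields (distribution, version, optional build tag, Python tag, ABI tag, platform tag) so the two filenames in this question can be compared. One caveat, stated as an assumption: a 0.6.2 wheel without such a label is not, on its own, proof of a CPU-only install, because recent JAX releases appear to ship CUDA support in separate plugin packages rather than inside jaxlib.

```python
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    distribution: str
    version: str          # public version, e.g. "0.4.23"
    local: Optional[str]  # local label after "+", e.g. "cuda12.cudnn88"
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_name(filename: str) -> WheelName:
    """Split a wheel filename into its dash-separated fields (hypothetical helper)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:
        # An optional build tag sits in third position; drop it.
        parts = parts[:2] + parts[3:]
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel filename layout: {filename}")
    dist, version, py_tag, abi_tag, plat_tag = parts
    public, _, local = version.partition("+")
    return WheelName(dist, public, local or None, py_tag, abi_tag, plat_tag)


if __name__ == "__main__":
    for name in (
        "jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl",
        "jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl",
    ):
        wheel = parse_wheel_name(name)
        print(wheel.version, wheel.local or "(no local label)")
```

Because the CUDA and cuDNN versions are baked into the old filenames, a hard-coded URL only keeps working for as long as that exact build stays hosted.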
Attempt #3: The Staged Plugin Method
This involves installing the CUDA plugin first, then JAX.
# In a clean environment...
pip install --upgrade "jax-cuda12-plugin"
pip install jax
Expected result: jax uses the CUDA-enabled jaxlib provided by the plugin.
Actual result: installing jax reinstalls the CPU version of jaxlib over the plugin's libraries, leading back to the same "fallback to CPU" problem as Attempt #1.
I am completely stuck.
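One way to narrow down a situation like this is to compare the installed jaxlib version with the versions of the CUDA plugin packages in the same environment. The sketch assumes (an assumption, not something stated in this thread) that jax-cuda12-plugin and jax-cuda12-pjrt are released in lockstep with jaxlib and should carry the same version number; lockstep_mismatches is a hypothetical helper.

```python
from importlib import metadata
from typing import Dict, List, Optional

# Packages assumed to share jaxlib's version number (assumption).
LOCKSTEP = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def lockstep_mismatches(versions: Dict[str, Optional[str]]) -> List[str]:
    """Describe packages that are missing or whose version differs from jaxlib's."""
    base = versions.get("jaxlib")
    if base is None:
        return ["jaxlib is not installed"]
    problems = []
    for name in LOCKSTEP:
        version = versions.get(name)
        if version is None:
            problems.append(f"{name} is not installed")
        elif version != base:
            problems.append(f"{name} {version} does not match jaxlib {base}")
    return problems


if __name__ == "__main__":
    # Gather versions from the active environment; None marks a missing package.
    versions: Dict[str, Optional[str]] = {}
    for name in ("jaxlib",) + LOCKSTEP:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    for line in lockstep_mismatches(versions) or ["versions are consistent"]:
        print(line)
```

An empty result only means the version numbers agree; it does not by itself show that the GPU backend loads.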
Given my server specifications (Ubuntu 18.04, CUDA 12.4 compatible driver), what is the current, definitive, and guaranteed-to-work set of commands to install a version of JAX that successfully uses the GPU?
I would stick with the standard installation method – in your case, this is the relevant warning:
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
You're using an old installation command: the cuda12-pip extra was removed in JAX v0.6.0. JAX's current installation instructions recommend this:
pip install -U "jax[cuda12]"
If you use this command, you should get the correct GPU-specific installation of the latest JAX version compatible with your platform and the Python version you're using (JAX v0.6.2 in the case of Linux / Python 3.10).
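After running that install command, a quick check slightly more targeted than printing jax.devices() is to ask JAX which backend it selected by default. jax.default_backend() returns a platform name such as "cpu" or "gpu"; the wrapper detected_backend and its "not-installed" sentinel are hypothetical additions for this sketch.

```python
def detected_backend() -> str:
    """Return JAX's default backend name, or "not-installed" if jax is missing."""
    try:
        import jax  # imported lazily so the check runs even without JAX
    except ImportError:
        return "not-installed"
    # Typically "gpu" when the CUDA backend loaded, "cpu" after a fallback.
    return jax.default_backend()


if __name__ == "__main__":
    backend = detected_backend()
    print("JAX default backend:", backend)
    if backend not in ("gpu", "not-installed"):
        print("JAX did not select a GPU backend; check the installed packages.")
```

On a working GPU install this should report "gpu"; "cpu" corresponds to the same fallback shown in the question.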
I noticed in some of your examples you're pinning old JAX versions (v0.4.23). If you're interested in installing jaxlib wheels for older JAX versions, there are some tips at https://docs.jax.dev/en/latest/installation.html#installing-older-jaxlib-wheels. If you're specifically having issues related to installing older jaxlib versions, I'd suggest posting another question on that topic, being very clear about which JAX version you're trying to install.