python, pip, cuda, nvidia, jax

How to correctly install JAX with CUDA on Linux when `jax[cuda12_pip]` consistently falls back to the CPU version?


I am trying to install JAX with GPU support on a dedicated Linux server, but I am stuck in what feels like a Catch-22: every official installation method fails in a different way, and JAX always ends up falling back to the CPU.

I am looking for a definitive, foolproof set of commands to get a working GPU installation.

System Specifications: Ubuntu 18.04, with an NVIDIA driver compatible with CUDA 12.4.

What I Have Tried

I have meticulously created fresh conda environments for each attempt to ensure there are no conflicts.

Attempt #1: The Standard Recommended Method

This is the officially recommended command.

$ conda create -n jax_test python=3.10 -y
$ conda activate jax_test
$ pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
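The warning spells the extra as 'cuda12-pip' even though the command asked for `cuda12_pip`. That is consistent with pip normalizing extra names as PEP 685 describes, reusing the PEP 503 project-name rule (runs of `-`, `_` and `.` collapse to a single `-`, and the result is lowercased). A minimal sketch of that rule; `normalize_extra` is a hypothetical helper name:

```python
import re


def normalize_extra(name: str) -> str:
    """Normalize an extra name per PEP 685 (same rule as PEP 503 names):
    collapse runs of '-', '_' and '.' into one '-' and lowercase it."""
    return re.sub(r"[-_.]+", "-", name).lower()


# The extra requested in Attempt #1 versus the one recommended today.
print(normalize_extra("cuda12_pip"))  # cuda12-pip
print(normalize_extra("cuda12"))      # cuda12
```

Both spellings name the same extra, so the warning is not a typo: jax 0.6.2 defines no extra under that name, and pip goes on to install plain jax and jaxlib, which matches the CPU-only result above.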

Attempt #2: The Direct URL Method

This workaround forces installation of a specific GPU wheel by pointing pip at its direct URL.

# In a clean environment...
$ pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
$ pip install jax==0.4.23 "numpy<2.0"
ERROR: HTTP error 404 while getting https://.../jaxlib-0.4.23...whl

This shows that relying on specific, old URLs is not a stable solution.
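When hunting for an older wheel by hand, the filename itself states which interpreter and platform it targets. The sketch below splits the filename from Attempt #2 into its PEP 427 fields; `parse_wheel_filename` and `WheelTags` are hypothetical names, and it assumes the optional build tag is absent. The `+cuda12.cudnn88` suffix is a local version label that appears to mark a CUDA 12 / cuDNN 8.8 build, and the `cp310` tags must match the Python 3.10 interpreter in the environment.

```python
from typing import NamedTuple


class WheelTags(NamedTuple):
    distribution: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename into its PEP 427 fields.

    Minimal sketch: assumes no optional build tag, which holds for the
    jaxlib filename from Attempt #2.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError(f"expected 5 dash-separated fields, got {len(parts)}")
    return WheelTags(*parts)


tags = parse_wheel_filename(
    "jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
)
print(tags.version)       # 0.4.23+cuda12.cudnn88
print(tags.python_tag)    # cp310
print(tags.platform_tag)  # manylinux2014_x86_64
```

Parsing the name does not bring back a file that returns 404, but it shows which Python and platform tags a replacement wheel would need.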

Attempt #3: The Staged Plugin Method

This involves installing the CUDA plugin first, then JAX.

# In a clean environment...
pip install --upgrade "jax-cuda12-plugin"
pip install jax
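To see what a staged install like this actually produced, one option is to list the JAX-related distributions and their versions. The sketch below uses only `importlib.metadata`; the distribution names `jax-cuda12-plugin` and `jax-cuda12-pjrt` are the PyPI names used by recent pip-based GPU installs, and the rule that the plugin packages should carry the same version as jaxlib is a heuristic assumption, not a documented guarantee. `installed_versions` and `find_problems` are hypothetical helpers.

```python
from importlib import metadata

# Distributions involved in a pip-based JAX GPU install (PyPI names).
JAX_DISTS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names=JAX_DISTS):
    """Map each distribution name to its installed version, or None."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_problems(versions):
    """Return likely problems in a {name: version-or-None} mapping.

    Heuristic assumption: the CUDA plugin packages should be the same
    release as jaxlib.
    """
    problems = []
    jaxlib = versions.get("jaxlib")
    for name in ("jax-cuda12-plugin", "jax-cuda12-pjrt"):
        version = versions.get(name)
        if version is None:
            problems.append(f"{name} is not installed")
        elif jaxlib is not None and version != jaxlib:
            problems.append(f"{name} {version} does not match jaxlib {jaxlib}")
    return problems


if __name__ == "__main__":
    versions = installed_versions()
    for name, version in versions.items():
        print(f"{name:20} {version or '(not installed)'}")
    problems = find_problems(versions)
    print("\n".join(problems) if problems else "no obvious version problems")
```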

My Question

I am completely stuck.

Given my server specifications (Ubuntu 18.04, CUDA 12.4 compatible driver), what is the current, definitive, and guaranteed-to-work set of commands to install a version of JAX that successfully uses the GPU?


Solution

  • Given my server specifications (Ubuntu 18.04, CUDA 12.4 compatible driver), what is the current, definitive, and guaranteed-to-work set of commands to install a version of JAX that successfully uses the GPU?

    I would stick with the standard installation method. In your case, this is the relevant warning:

    WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
    

    You're using an old installation command: the cuda12-pip extra was removed in JAX v0.6.0. JAX's current installation instructions recommend this:

    pip install -U "jax[cuda12]"
    

    If you use this command, you should get the correct GPU-specific installation of the latest JAX version compatible with your platform and the Python version you're using (JAX v0.6.2 in the case of Linux / Python 3.10).
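After installing, one quick way to confirm that JAX picked up the GPU is to ask for its default backend. The sketch below uses the public `jax.default_backend()` and `jax.devices()` functions; `jax_backend_report` is a hypothetical helper, and it imports jax lazily so the script also runs where jax is absent. On a working CUDA install the backend is expected to be `gpu`; `cpu` means the fallback described in the question is still happening.

```python
import importlib.util


def jax_backend_report():
    """Return (default_backend, device strings), or None if jax is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this check runs without jax installed

    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    report = jax_backend_report()
    if report is None:
        print("jax is not installed in this environment")
    else:
        backend, devices = report
        print("default backend:", backend)
        print("devices:", devices)
        if backend == "cpu":
            print("JAX fell back to CPU: no GPU backend was initialized")
```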


    I noticed in some of your examples you're pinning old JAX versions (v0.4.23). If you're interested in installing jaxlib wheels for older JAX versions, there are some tips at https://docs.jax.dev/en/latest/installation.html#installing-older-jaxlib-wheels. If you're specifically having issues related to installing older jaxlib versions, I'd suggest posting another question on that topic, being very clear about which JAX version you're trying to install.