Tags: cuda, gpu, conda, nvidia, mamba

Conda CUDA 12 incompatibility


I'm trying to install the pl_upgrades branch of OpenFold on a GCP VM with an NVIDIA L4 attached.

Following the instructions, I installed Mambaforge and created an environment using their environment.yml file. However, there appears to be a mismatch between the CUDA version that PyTorch and the other libraries are built against (12.1) and the cudatoolkit version that got installed (11.8).

cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cudatoolkit               11.8.0              h4ba93d1_13    conda-forge
torchtriton               2.1.0                     py310    pytorch
pytorch                   2.1.2           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-lightning         2.2.2              pyhd8ed1ab_0    conda-forge
[truncated `conda list` output]
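
As a sanity check (my own, not part of the OpenFold instructions), the CUDA build that PyTorch itself reports can be read straight from the interpreter:

$ # torch.version.cuda is the CUDA version this PyTorch build was compiled
$ # against (12.1 for the build above); is_available() confirms the driver
$ # is usable from inside the environment.
$ python -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"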

I also already had CUDA 12.1 installed on the system, along with NVIDIA driver version 535.86.10.

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
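
As an aside (my own note, not from the docs): nvcc reports the toolkit installed on the system, while the banner at the top of nvidia-smi reports the newest CUDA version the driver itself supports, which for a 535-series driver should be 12.2. The two don't have to match.

$ # The header line shows the driver version and the maximum CUDA version
$ # that driver supports, independent of whichever toolkit nvcc comes from.
$ nvidia-smi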

I've noticed that the cudatoolkit package on conda-forge and the nvidia channel only goes up to v11.8; from v12 onwards it is seemingly replaced by the cuda-toolkit meta-package published under labels like nvidia/label/cuda-x.y.0. I tried adding nvidia/label/cuda-12.1.0::cuda-toolkit to the environment.yml, but it then started installing a lot of CUDA 12.4 libraries alongside the 12.1 ones, including nvcc 12.4.
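
One thing that might keep the solver on 12.1 (a sketch only; I haven't verified it fixes the mixed 12.1/12.4 solve): list the version-labelled NVIDIA channel ahead of the plain nvidia channel and turn on strict channel priority, so dependencies can't be pulled from the newer 12.4 builds:

$ # Sketch: prefer the 12.1.0-labelled channel and forbid the solver from
$ # mixing in packages from lower-priority channels.
$ conda config --set channel_priority strict
$ mamba install -c nvidia/label/cuda-12.1.0 -c conda-forge "cuda-toolkit=12.1.0"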

I'm not sure how to resolve this; I just want CUDA 12.1, and only 12.1, to be installed.


Solution

  • If anyone else is having this issue, especially with an RTX 4090: the pl_upgrades branch (commit 3bec3e9) worked well enough. Reverting to flash-attn v2.0 got me past one of the unit-test errors.

    But what finally gave me a stable environment was installing CUDA 12.1 from the NVIDIA runfile rather than through conda, roughly as sketched below.
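
    For reference, a rough sketch of the runfile route (treat the URL/filename as a template and take the exact installer link from NVIDIA's CUDA 12.1.0 download archive):

    $ # Install only the toolkit from the runfile; skip the bundled driver,
    $ # since the 535-series driver is already installed.
    $ wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
    $ sudo sh cuda_12.1.0_530.30.02_linux.run --silent --toolkit
    $ # Point the shell and the dynamic linker at the new toolkit.
    $ export PATH=/usr/local/cuda-12.1/bin:$PATH
    $ export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH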