Tags: python, dynamic-memory-allocation, jax

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed


I am trying to run multiple sbx programs (that use JAX) concurrently using joblib. Here is my program -

'''
For installation please do -
pip install gym
pip install sbx-rl
pip install mujoco
pip install shimmy
'''
from joblib import Parallel, delayed

import gym
from sbx import SAC

# from stable_baselines3 import SAC
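def train():
    # Each worker process builds its own environment and SAC model.
    env = gym.make("Humanoid-v4")

    model = SAC("MlpPolicy", env, verbose=1)
    # learn() documents total_timesteps as an int, so pass 700_000 rather than 7e5.
    model.learn(total_timesteps=700_000, progress_bar=True)


def train_model():
    train()


if __name__ == '__main__':
    # Start three training jobs in separate joblib worker processes.
    Parallel(n_jobs=10)(delayed(train)() for i in range(3))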


This is the error that I am getting -

/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
  warnings.warn(
/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
  warnings.warn(
/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
  warnings.warn(
2024-01-30 11:19:12.354168: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory
2024-01-30 11:19:12.354264: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.
joblib.externals.loky.process_executor._RemoteTraceback: 
"""
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 463, in _process_worker
    r = call_item()
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
    return self.fn(*self.args, **self.kwargs)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 589, in __call__
    return [func(*args, **kwargs)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 589, in <listcomp>
    return [func(*args, **kwargs)
  File "/work/LAS/usr/tbd/5_test.py", line 23, in my_func
    model = SAC("MlpPolicy", env,verbose=0)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/sac.py", line 109, in __init__
    self._setup_model()
  File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/sac.py", line 126, in _setup_model
    self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/policies.py", line 143, in build
    key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/random.py", line 303, in split
    return _return_prng_keys(wrapped, _split(typed_key, num))
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/random.py", line 289, in _split
    return prng.random_split(key, shape=shape)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 769, in random_split
    return random_split_p.bind(keys, shape=shape)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 781, in random_split_impl
    base_arr = random_split_impl_base(
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 787, in random_split_impl_base
    return split(base_arr)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 786, in <lambda>
    split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
  File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 1291, in threefry_split
    return _threefry_split(key, shape)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/work/LAS/usr/tbd/5_test.py", line 27, in <module>
    Parallel(n_jobs=3)(delayed(my_func)() for i in range(3))
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
    return output if self.return_generator else list(output)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
    yield from self._retrieve()
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
    self._raise_error_fast()
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
    error_job.get_result(self.timeout)
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
    return self._return_or_raise()
  File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
    raise self._result
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.


I am using a 40 GB GPU (A100 PCIe), so I doubt that the GPU is actually running out of memory. Please let me know if any clarification is needed.

Edit 1: This is how I call my program - export XLA_PYTHON_CLIENT_PREALLOCATE=false && python 5_test.py (The name of my program is 5_test.py)


Solution

  • It appears you are using multiple processes targeting the same GPU. In each process, JAX will attempt to reserve 75% of the available GPU memory (see GPU memory allocation), so attempting this with two or more processes will exhaust the available memory.

    You could fix this by turning off pre-allocation as mentioned in that doc, by setting the environment variables XLA_PYTHON_CLIENT_PREALLOCATE=false or XLA_PYTHON_CLIENT_MEM_FRACTION=.XX (with .XX set to .08 or something suitable), but I suspect the end result will be less efficient than if you had just run your full program from a single JAX process: multiple host processes targeting a single GPU device concurrently will just compete with each other for resources and lead to suboptimal results.
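
    As a rough illustration of the point above, here is a minimal sketch of setting these variables from Python inside each worker. It assumes the variables have to be in place before JAX initializes its GPU backend in that process, so the JAX-dependent imports are deferred into the worker function. The helper names configure_jax_gpu_memory, total_reserved_gb and train_worker are hypothetical, not part of JAX, sbx or joblib.

```python
import os
from typing import Optional


def total_reserved_gb(n_processes: int, gpu_gb: float, fraction: float = 0.75) -> float:
    """Memory the processes would try to reserve in total with preallocation on."""
    return n_processes * fraction * gpu_gb


def configure_jax_gpu_memory(mem_fraction: Optional[float] = None) -> None:
    """Set the XLA client memory options for the current process.

    Call this before JAX is imported in the process (assumption: the
    settings are read when the GPU backend is initialized).
    """
    if mem_fraction is None:
        # No fraction given: turn preallocation off so memory is allocated on demand.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    else:
        # Keep preallocation, but reserve only the given share of GPU memory.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


def train_worker(mem_fraction: Optional[float] = None) -> None:
    """Hypothetical worker body mirroring train() from the question."""
    configure_jax_gpu_memory(mem_fraction)

    # Deferred imports so JAX starts up only after the variables are set.
    import gym
    from sbx import SAC

    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=700_000, progress_bar=True)


if __name__ == "__main__":
    # Three processes at the default 75% on a 40 GB card ask for more than exists.
    print(total_reserved_gb(3, 40))
```

    With joblib this would be called as Parallel(n_jobs=3)(delayed(train_worker)(0.25) for _ in range(3)), where 0.25 is only an example share for three workers. This addresses the allocation failure only; the workers still compete for the same GPU, as noted above.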