pythonpytorchreinforcement-learningtransformer-modelrllib

I keep getting this error, cuda available 'RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu


I'm training a transformer model using RLlib's PPO algorithm, but I encounter a device mismatch error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Despite moving all model components to the GPU with to(self.device), the error persists. CUDA is available, and the model is intended to run on the GPU.

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class SimpleTransformer(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        
        # Configuration
        custom_config = model_config["custom_model_config"]
        self.input_dim = 76
        self.seq_len = custom_config["seq_len"]
        self.embed_size = custom_config["embed_size"]
        self.nheads = custom_config["nhead"]
        self.nlayers = custom_config["nlayers"]
        self.dropout = custom_config["dropout"]
        self.values_out = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Input layer
        self.input_embed = nn.Linear(self.input_dim, self.embed_size).to(self.device)
        
        # Positional encoding
        self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size).to(self.device)
        
        # Transformer
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=self.nheads,
                dropout=self.dropout,
                activation='gelu',
                device=self.device), 
            num_layers=self.nlayers
        )
        
        # Policy and value heads
        self.policy_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64), # Add dynamic features (wallet balance, unrealized PnL)
            nn.ReLU(),
            nn.Linear(64, num_outputs) # Action space size
        ).to(self.device)

        self.value_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.device)

    def forward(self, input_dict, state, seq_len):
        # Process input
        x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
        dynamic_features = x[:, -1, 2:4].clone().to(self.device)
        x = self.input_embed(x)
        
        position = torch.arange(0, self.seq_len).unsqueeze(0).expand(x.size(0), -1).to(self.device)
        x = x + self.pos_encoding(position)
        
        transformer_out = self.transformer(x)
        last_out = transformer_out[:, -1, :]
         
        combined = torch.cat((last_out, dynamic_features), dim=1)
        
        actions = self.policy_head(combined)
        self.values_out = self.value_head(combined).squeeze(1)
        
        return actions, state

Here is the full Error message:

Trial status: 1 ERROR
Current time: 2025-04-11 20:44:55. Total running time: 14s
Logical resource usage: 0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
╭──────────────────────────────────────╮
│ Trial name                  status   │
├──────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000   ERROR    │
╰──────────────────────────────────────╯

Number of errored trials: 1
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name                    # failures   error file                                                                                                                                                                                                      │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000              1   C:/Users/tmpou/AppData/Local/Temp/ray/session_2025-04-11_20-44-35_479257_23712/artifacts/2025-04-11_20-44-40/PPO_2025-04-11_20-44-40/driver_artifacts/PPO_CryptoEnv_a50d0_00000_0_2025-04-11_20-44-40/error.txt │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Traceback (most recent call last):
  File "C:\Users\tmpou\Developer\MSc AI\Deep Learning and Multi-media data\crypto_rl_bot\train.py", line 14, in <module>
    tune.run(
  File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\tune.py", line 1042, in run
    raise TuneError("Trials did not complete", incomplete_trials)
ray.tune.error.TuneError: ('Trials did not complete', [PPO_CryptoEnv_a50d0_00000])
(PPO pid=31224) 2025-04-11 20:44:55,030 ERROR actor_manager.py:517 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224)     self._build_policy_map(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224)     new_policy = create_policy_for_framework(
(PPO pid=31224)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224)     self._initialize_loss_from_dummy_batch()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224)     self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224)     return self.dist.log_prob(actions)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224)     return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 229, in _setup
(PPO pid=31224)     self.add_workers(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 682, in add_workers
(PPO pid=31224)     raise result.get()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\actor_manager.py", line 497, in _fetch_result
(PPO pid=31224)     result = ray.get(r)
(PPO pid=31224)              ^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\auto_init_hook.py", line 21, in auto_init_wrapper
(PPO pid=31224)     return fn(*args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\client_mode_hook.py", line 103, in wrapper
(PPO pid=31224)     return func(*args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 2667, in get
(PPO pid=31224)     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
(PPO pid=31224)                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 866, in get_objects
(PPO pid=31224)     raise value
(PPO pid=31224) ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224)     self._build_policy_map(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224)     new_policy = create_policy_for_framework(
(PPO pid=31224)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224)     self._initialize_loss_from_dummy_batch()
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224)     self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224)     return self.dist.log_prob(actions)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224)     return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224)
(PPO pid=31224) During handling of the above exception, another exception occurred:
(PPO pid=31224)
(PPO pid=31224) ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224)     return method(__ray_actor, *args, **kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 533, in __init__
(PPO pid=31224)     super().__init__(
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\trainable\trainable.py", line 161, in __init__
(PPO pid=31224)     self.setup(copy.deepcopy(self.config))
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224)     return method(self, *_args, **_kwargs)
(PPO pid=31224)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 631, in setup
(PPO pid=31224)     self.workers = WorkerSet(
(PPO pid=31224)                    ^^^^^^^^^^
(PPO pid=31224)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 181, in __init__
(PPO pid=31224)     raise e.args[0].args[2]
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(RolloutWorker pid=3964) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(RolloutWorker pid=3964)   File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(RolloutWorker pid=3964)   File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(RolloutWorker pid=3964)     return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RolloutWorker pid=3964)     return method(self, *_args, **_kwargs) [repeated 3x across cluster]
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 3x across cluster]
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__ [repeated 2x across cluster]
(RolloutWorker pid=3964)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(RolloutWorker pid=3964)     self._build_policy_map(
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(RolloutWorker pid=3964)     new_policy = create_policy_for_framework(
(RolloutWorker pid=3964)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(RolloutWorker pid=3964)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)     self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=3964)     self.loss(self.model, self.dist_class, train_batch)
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(RolloutWorker pid=3964)     curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(RolloutWorker pid=3964)     return self.dist.log_prob(actions)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964)   File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(RolloutWorker pid=3964)     return log_pmf.gather(-1, value).squeeze(-1)
(RolloutWorker pid=3964)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)

Solution

  • To resolve the device mismatch error, you should let RLlib and PyTorch manage device placement automatically.

    It's also important to override the forward and value_function methods, as suggested by @Marzi Heifari.

    Here is the modified version:

    import torch
    import torch.nn as nn
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.utils.annotations import override, DeveloperAPI
    from ray.rllib.models.modelv2 import ModelV2
    
    @DeveloperAPI
    class SimpleTransformer(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
            nn.Module.__init__(self)
            
            # Configuration
            custom_config = model_config["custom_model_config"]
            self.input_dim = 76
            self.seq_len = custom_config["seq_len"]
            self.embed_size = custom_config["embed_size"]
            self.nheads = custom_config["nhead"]
            self.nlayers = custom_config["nlayers"]
            self.dropout = custom_config["dropout"]
            self.values_out = None
            self.device = None
    
            # Input layer
            self.input_embed = nn.Linear(self.input_dim, self.embed_size)
            
            # Positional encoding
            self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size)
            
            # Transformer
            self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=self.embed_size,
                    nhead=self.nheads,
                    dropout=self.dropout,
                    activation='gelu'), 
                num_layers=self.nlayers
            )
            
            # Policy and value heads
            self.policy_head = nn.Sequential(
                nn.Linear(self.embed_size + 2, 64), # Add dynamic features (wallet balance, unrealized PnL)
                nn.ReLU(),
                nn.Linear(64, num_outputs) # Action space size
            )
    
            self.value_head = nn.Sequential(
                nn.Linear(self.embed_size + 2, 64),
                nn.ReLU(),
                nn.Linear(64, 1)
            )
    
        @override(ModelV2)
        def forward(self, input_dict, state, seq_lens):
            self.device = input_dict["obs"].device
            x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
            dynamic_features = x[:, -1, 2:4].clone()
        
            x = self.input_embed(x)
            position = torch.arange(0, self.seq_len, device=self.device).unsqueeze(0).expand(x.size(0), -1)
            x = x + self.pos_encoding(position)
        
            transformer_out = self.transformer(x)
            last_out = transformer_out[:, -1, :]
            combined = torch.cat((last_out, dynamic_features), dim=1)
        
            logits = self.policy_head(combined)
            self.values_out = self.value_head(combined).squeeze(1)
        
            return logits, state
        
        @override(ModelV2)
        def value_function(self):
            return self.values_out.to(self.device)