I'm training a transformer model using RLlib's PPO algorithm, but I encounter a device mismatch error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Despite moving all model components to the GPU with `.to(self.device)`, the error persists. CUDA is available, and the model is intended to run on the GPU.
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class SimpleTransformer(TorchModelV2, nn.Module):
    """Transformer-encoder policy/value model for RLlib (ModelV2 API).

    Each observation is reshaped to ``(batch, seq_len, input_dim)``,
    linearly embedded, offset by a learned positional embedding, and run
    through a ``TransformerEncoder``. The encoder output at the last time
    step is concatenated with two raw features (columns 2:4 of the last
    observation row) and fed to separate policy and value heads.

    Reads ``model_config["custom_model_config"]`` keys ``seq_len``,
    ``embed_size``, ``nhead``, ``nlayers``, ``dropout`` and, optionally,
    ``input_dim`` (defaults to 76).

    NOTE(review): this version pins every layer and every input to
    ``"cuda"`` whenever CUDA is available, regardless of the device the
    RLlib policy placed the model and its train batch on. That is what
    produces the reported "cuda:0 and cpu" mismatch; see the answer below.
    NOTE(review): no ``value_function`` override is defined here, although
    ``self.values_out`` is computed in ``forward``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Configuration
        custom_config = model_config["custom_model_config"]
        # Per-time-step feature width; 76 unless overridden in the config.
        self.input_dim = custom_config.get("input_dim", 76)
        self.seq_len = custom_config["seq_len"]
        self.embed_size = custom_config["embed_size"]
        self.nheads = custom_config["nhead"]
        self.nlayers = custom_config["nlayers"]
        self.dropout = custom_config["dropout"]
        # Value estimates cached by forward().
        self.values_out = None
        # Hard-coded device choice (the source of the reported mismatch).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Input layer
        self.input_embed = nn.Linear(self.input_dim, self.embed_size).to(self.device)
        # Positional encoding: one learned vector per position in the window.
        self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size).to(self.device)
        # Transformer. batch_first=True because forward() feeds
        # (batch, seq, embed); with the default (False) the encoder would
        # treat dim 0 as the sequence axis and attend across batch samples.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=self.nheads,
                dropout=self.dropout,
                activation='gelu',
                batch_first=True,
                device=self.device),
            num_layers=self.nlayers
        )

        # Policy and value heads
        self.policy_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),  # Add dynamic features (wallet balance, unrealized PnL)
            nn.ReLU(),
            nn.Linear(64, num_outputs)  # Action space size
        ).to(self.device)
        self.value_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.device)

    def forward(self, input_dict, state, seq_len):
        """Return ``(logits, state)`` and cache value estimates in ``self.values_out``."""
        # Process input: (batch, seq_len, input_dim), forced onto self.device.
        x = input_dict["obs"].view(-1, self.seq_len, self.input_dim).to(self.device)
        # Raw columns 2:4 of the last time step, taken before embedding.
        dynamic_features = x[:, -1, 2:4].clone().to(self.device)
        x = self.input_embed(x)
        position = torch.arange(0, self.seq_len).unsqueeze(0).expand(x.size(0), -1).to(self.device)
        x = x + self.pos_encoding(position)
        transformer_out = self.transformer(x)
        # Encoder output at the final time step: (batch, embed_size).
        last_out = transformer_out[:, -1, :]
        combined = torch.cat((last_out, dynamic_features), dim=1)
        actions = self.policy_head(combined)
        self.values_out = self.value_head(combined).squeeze(1)
        return actions, state
Here is the full error message:
Trial status: 1 ERROR
Current time: 2025-04-11 20:44:55. Total running time: 14s
Logical resource usage: 0/12 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)
╭──────────────────────────────────────╮
│ Trial name status │
├──────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000 ERROR │
╰──────────────────────────────────────╯
Number of errored trials: 1
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name # failures error file │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ PPO_CryptoEnv_a50d0_00000 1 C:/Users/tmpou/AppData/Local/Temp/ray/session_2025-04-11_20-44-35_479257_23712/artifacts/2025-04-11_20-44-40/PPO_2025-04-11_20-44-40/driver_artifacts/PPO_CryptoEnv_a50d0_00000_0_2025-04-11_20-44-40/error.txt │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Traceback (most recent call last):
File "C:\Users\tmpou\Developer\MSc AI\Deep Learning and Multi-media data\crypto_rl_bot\train.py", line 14, in <module>
tune.run(
File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\tune.py", line 1042, in run
raise TuneError("Trials did not complete", incomplete_trials)
ray.tune.error.TuneError: ('Trials did not complete', [PPO_CryptoEnv_a50d0_00000])
(PPO pid=31224) 2025-04-11 20:44:55,030 ERROR actor_manager.py:517 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224) return method(__ray_actor, *args, **kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224) self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224) self._build_policy_map(
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224) new_policy = create_policy_for_framework(
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224) return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224) self._initialize_loss_from_dummy_batch()
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224) self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224) curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224) return self.dist.log_prob(actions)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224) return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 229, in _setup
(PPO pid=31224) self.add_workers(
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 682, in add_workers
(PPO pid=31224) raise result.get()
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\actor_manager.py", line 497, in _fetch_result
(PPO pid=31224) result = ray.get(r)
(PPO pid=31224) ^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\auto_init_hook.py", line 21, in auto_init_wrapper
(PPO pid=31224) return fn(*args, **kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\client_mode_hook.py", line 103, in wrapper
(PPO pid=31224) return func(*args, **kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 2667, in get
(PPO pid=31224) values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\worker.py", line 866, in get_objects
(PPO pid=31224) raise value
(PPO pid=31224) ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224) return method(__ray_actor, *args, **kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 535, in __init__
(PPO pid=31224) self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=31224) self._build_policy_map(
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=31224) new_policy = create_policy_for_framework(
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(PPO pid=31224) return policy_class(observation_space, action_space, merged_config)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__
(PPO pid=31224) self._initialize_loss_from_dummy_batch()
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(PPO pid=31224) self.loss(self.model, self.dist_class, train_batch)
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(PPO pid=31224) curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(PPO pid=31224) return self.dist.log_prob(actions)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(PPO pid=31224) return log_pmf.gather(-1, value).squeeze(-1)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(PPO pid=31224)
(PPO pid=31224) During handling of the above exception, another exception occurred:
(PPO pid=31224)
(PPO pid=31224) ray::PPO.__init__() (pid=31224, ip=127.0.0.1, actor_id=f5d50e01341cb51a747d8a3e01000000, repr=PPO)
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1883, in ray._raylet.execute_task
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1984, in ray._raylet.execute_task
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(PPO pid=31224) File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(PPO pid=31224) return method(__ray_actor, *args, **kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 533, in __init__
(PPO pid=31224) super().__init__(
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\tune\trainable\trainable.py", line 161, in __init__
(PPO pid=31224) self.setup(copy.deepcopy(self.config))
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
(PPO pid=31224) return method(self, *_args, **_kwargs)
(PPO pid=31224) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 631, in setup
(PPO pid=31224) self.workers = WorkerSet(
(PPO pid=31224) ^^^^^^^^^^
(PPO pid=31224) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\worker_set.py", line 181, in __init__
(PPO pid=31224) raise e.args[0].args[2]
(PPO pid=31224) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
(RolloutWorker pid=3964) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=3964, ip=127.0.0.1, actor_id=b2fed95453b6755f07372fcb01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001D47885E850>)
(RolloutWorker pid=3964) File "python\ray\_raylet.pyx", line 1889, in ray._raylet.execute_task
(RolloutWorker pid=3964) File "python\ray\_raylet.pyx", line 1830, in ray._raylet.execute_task.function_executor
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\_private\function_manager.py", line 724, in actor_method_executor
(RolloutWorker pid=3964) return method(__ray_actor, *args, **kwargs)
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RolloutWorker pid=3964) return method(self, *_args, **_kwargs) [repeated 3x across cluster]
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 3x across cluster]
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 64, in __init__ [repeated 2x across cluster]
(RolloutWorker pid=3964) self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1743, in _update_policy_map
(RolloutWorker pid=3964) self._build_policy_map(
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 1854, in _build_policy_map
(RolloutWorker pid=3964) new_policy = create_policy_for_framework(
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\utils\policy.py", line 141, in create_policy_for_framework
(RolloutWorker pid=3964) return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\policy\policy.py", line 1484, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=3964) self.loss(self.model, self.dist_class, train_batch)
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_torch_policy.py", line 112, in loss
(RolloutWorker pid=3964) curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\ray\rllib\models\torch\torch_action_dist.py", line 37, in logp
(RolloutWorker pid=3964) return self.dist.log_prob(actions)
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) File "C:\Users\tmpou\miniconda3\envs\crypto_bot\Lib\site-packages\torch\distributions\categorical.py", line 143, in log_prob
(RolloutWorker pid=3964) return log_pmf.gather(-1, value).squeeze(-1)
(RolloutWorker pid=3964) ^^^^^^^^^^^^^^^^^^^^^^^^^
(RolloutWorker pid=3964) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
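To resolve the device mismatch error, let RLlib and PyTorch manage device placement instead of hard-coding it. In RLlib, the policy (not the custom model) decides which device to use, based on the `num_gpus` / `num_gpus_per_worker` settings, and it moves both the model and the training batch there. The traceback shows the failure is raised while a `RolloutWorker` builds its policy. The `index` argument of `gather` (the actions from the training batch) is on the CPU, while the logits come from a model that had forced itself onto `cuda:0` — hence the mismatch. The changes are:
- Layers are no longer explicitly moved with `.to(self.device)` during initialization.
- The device is detected dynamically from the input: `self.device = input_dict["obs"].device`.
- Only the inputs in the `forward` method and `values_out` in the `value_function` method are moved to the model's device manually.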
It's also important to override both the `forward` and `value_function` methods, as suggested by @Marzi Heifari.
Here is the modified version:
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.models.modelv2 import ModelV2
@DeveloperAPI
class SimpleTransformer(TorchModelV2, nn.Module):
    """Transformer-encoder policy/value model for RLlib (ModelV2 API).

    Each observation is reshaped to ``(batch, seq_len, input_dim)``,
    linearly embedded, offset by a learned positional embedding, and run
    through a ``TransformerEncoder``. The encoder output at the last time
    step is concatenated with two raw features (columns 2:4 of the last
    observation row) and fed to separate policy and value heads.

    No layer is moved to a device here; the owning RLlib policy places the
    whole module, and ``forward`` follows the device of the incoming batch.

    Reads ``model_config["custom_model_config"]`` keys ``seq_len``,
    ``embed_size``, ``nhead``, ``nlayers``, ``dropout`` and, optionally,
    ``input_dim`` (defaults to 76).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Configuration
        custom_config = model_config["custom_model_config"]
        # Per-time-step feature width; 76 unless overridden in the config.
        self.input_dim = custom_config.get("input_dim", 76)
        self.seq_len = custom_config["seq_len"]
        self.embed_size = custom_config["embed_size"]
        self.nheads = custom_config["nhead"]
        self.nlayers = custom_config["nlayers"]
        self.dropout = custom_config["dropout"]
        # Value estimates cached by forward() for value_function().
        self.values_out = None
        # Device of the most recent input batch; set in forward().
        self.device = None

        # Input layer
        self.input_embed = nn.Linear(self.input_dim, self.embed_size)
        # Positional encoding: one learned vector per position in the window.
        self.pos_encoding = nn.Embedding(self.seq_len, self.embed_size)
        # Transformer. batch_first=True because forward() feeds
        # (batch, seq, embed); with the default (False) the encoder would
        # treat dim 0 as the sequence axis and attend across batch samples.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=self.nheads,
                dropout=self.dropout,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=self.nlayers,
        )

        # Policy and value heads
        self.policy_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),  # Add dynamic features (wallet balance, unrealized PnL)
            nn.ReLU(),
            nn.Linear(64, num_outputs),  # Action space size
        )
        self.value_head = nn.Sequential(
            nn.Linear(self.embed_size + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return ``(logits, state)`` and cache value estimates.

        Args:
            input_dict: RLlib input dict; ``input_dict["obs"]`` must hold
                ``seq_len * input_dim`` values per batch row.
            state: Recurrent state list, returned unchanged.
            seq_lens: Unused (the window length is fixed by ``seq_len``).

        Returns:
            ``(logits, state)`` with logits of shape ``(batch, num_outputs)``.
        """
        obs = input_dict["obs"]
        self.device = obs.device
        # reshape (not view) tolerates non-contiguous obs; .float() matches the
        # float32 layer weights, as RLlib's built-in torch models do. The
        # .to() is a no-op here because self.device comes from obs itself.
        x = obs.reshape(-1, self.seq_len, self.input_dim).to(self.device).float()
        # Raw columns 2:4 of the last time step, taken before embedding.
        dynamic_features = x[:, -1, 2:4].clone()
        x = self.input_embed(x)
        # (1, seq_len) positions -> (1, seq_len, embed); broadcasts over batch.
        position = torch.arange(self.seq_len, device=self.device).unsqueeze(0)
        x = x + self.pos_encoding(position)
        transformer_out = self.transformer(x)
        # Encoder output at the final time step: (batch, embed_size).
        last_out = transformer_out[:, -1, :]
        combined = torch.cat((last_out, dynamic_features), dim=1)
        logits = self.policy_head(combined)
        self.values_out = self.value_head(combined).squeeze(1)
        return logits, state

    @override(ModelV2)
    def value_function(self):
        """Return the 1-D value estimates computed by the last ``forward`` call.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        if self.values_out is None:
            raise RuntimeError("value_function() called before forward()")
        return self.values_out.to(self.device)