I have this code to unpack a list of lists into separate tensors:
token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.LongTensor, zip(*batch))
If I want to create these tensors on the GPU, I can use the code below:
token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.cuda.LongTensor, zip(*batch))
But now I want to create all of them on a TPU. What should I do? Is there anything like the following?
token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.xla.LongTensor, zip(*batch))
You can use an XLA device, following the guide here.
You can select the device and pass it to your function like this:
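import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()
# Build int64 tensors directly on the device; torch.Tensor(x) would create float32 values first
token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(lambda x: torch.tensor(x, dtype=torch.long, device=device), zip(*batch))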
You can even parametrize the device variable: torch.device("cuda" if torch.cuda.is_available() else "cpu") can be used to select between cuda and cpu.
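Putting the two ideas together, here is a minimal sketch of a device-agnostic version. The helper names pick_device and unpack_batch are hypothetical (they are not part of torch or torch_xla), and the sketch assumes every field in the batch has the same shape across samples so it can be stacked into one tensor. It also assumes that torch_xla being importable means an XLA device is actually available.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer XLA (TPU), then CUDA, then CPU."""
    try:
        # torch_xla is only installed on TPU setups, so import it lazily.
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    except ImportError:
        pass
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def unpack_batch(batch, device):
    """Hypothetical helper: one int64 tensor per field, created on `device`.

    `batch` is a list of samples, each sample being a list of fields.
    zip(*batch) regroups the values field by field.
    """
    return tuple(
        torch.tensor(field, dtype=torch.long, device=device)
        for field in zip(*batch)
    )
```

With these in place, the original line would become something like `token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = unpack_batch(batch, pick_device())`, and the same code runs unchanged on CPU, GPU, or TPU.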