Map a list of lists to TPU tensors


I have this code to unpack a list of lists into separate tensors:

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.LongTensor, zip(*batch))
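The unpacking works because zip(*batch) transposes the batch. Each inner list is one sample, and the k-th tuple produced by zip collects field k from every sample. map then calls the tensor constructor once per field. The sketch below uses only the standard library and a hypothetical four-field batch (the field layout is an assumption chosen for illustration) to show the intermediate tuples:

```python
# Hypothetical batch: each inner list is one training sample holding
# (token_a_index, token_b_index, isNext, input_ids) in that order.
batch = [
    [0, 3, 1, [101, 7, 102]],
    [4, 9, 0, [101, 8, 102]],
]

# zip(*batch) transposes the batch: the k-th tuple collects field k from every sample.
columns = list(zip(*batch))

token_a_index, token_b_index, isNext, input_ids = columns
print(token_a_index)  # (0, 4)
print(input_ids)      # ([101, 7, 102], [101, 8, 102])
```

Each of these tuples is what gets passed to torch.LongTensor. Nested fields such as input_ids must already be padded to the same length, because a tensor built from nested lists has to be rectangular.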

If I want to create these tensors on the GPU, I can use the code below:

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.cuda.LongTensor, zip(*batch))

But now I want to create all of these on a TPU. What should I do? Is there something like the following?

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.xla.LongTensor, zip(*batch))

Solution

  • You can use an XLA device, following the PyTorch/XLA documentation.

    Select the device and pass it in when you build each tensor, like this:

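    import torch
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    # Build int64 tensors directly on the device; torch.Tensor(x).long() would round-trip through float32
    token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(lambda x: torch.tensor(x, dtype=torch.long, device=device), zip(*batch))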
    

    You can also parametrize the device variable. For example, torch.device("cuda" if torch.cuda.is_available() else "cpu") selects CUDA when it is available and falls back to the CPU otherwise.
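Extending that idea to cover TPUs as well, here is a minimal sketch that prefers an XLA device when the torch_xla package is installed, then CUDA, then the CPU. The helper names pick_device and batch_to_long_tensors are hypothetical, and the preference order is an assumption. Note that the presence of torch_xla alone does not guarantee that a TPU is attached.

```python
import importlib.util

import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer XLA (TPU), then CUDA, then CPU."""
    # torch_xla is a separate package, so only import it if it is installed.
    if importlib.util.find_spec("torch_xla") is not None:
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def batch_to_long_tensors(batch, device):
    """Hypothetical helper: transpose a list of per-sample lists and
    build one int64 tensor per field directly on `device`."""
    return [torch.tensor(column, dtype=torch.long, device=device) for column in zip(*batch)]
```

With these helpers, the seven tensors from the question could be unpacked in one line by assigning batch_to_long_tensors(batch, pick_device()) to token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos. Passing dtype=torch.long to torch.tensor avoids the float32 intermediate that torch.Tensor(x).long() creates.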