pytorch numpy-indexing

Batched index_fill in PyTorch


I have an index tensor of size (2, 3):

>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
        [3., 4., 7.]])

And a value tensor of size (2, 8):

>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

I want to set the elements of value to 1 at the positions given by index along dim=-1. The output should look like:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

I tried value[range(2), index] = 1, but it raises an error. I also tried torch.index_fill, but it does not accept batched indices. torch.scatter requires creating an extra tensor of size 2*8 filled with ones, which costs unnecessary memory and time.


Solution

  • You can use torch.Tensor.scatter_ and pass the scalar value argument instead of the tensor src argument, so no extra tensor of ones is needed.

    >>> value.scatter_(dim=-1, index=index.long(), value=1)
    
    >>> value
    tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
            [0., 0., 0., 1., 1., 0., 0., 1.]])
    

    Make sure index has dtype int64, though (hence the index.long() call above, since index was created as a float tensor).
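
For reference, here is a minimal self-contained sketch of the same idea as a script. It hard-codes the index values from the question (rather than drawing them at random) so that the result is reproducible, and it also shows an advanced-indexing variant as a possible alternative. The names out_scatter, out_indexed and rows are only illustrative and do not come from the original post.

```python
import torch

# Index values copied from the question, stored directly as int64
# (scatter_ expects an int64 index tensor).
index = torch.tensor([[6, 3, 2],
                      [3, 4, 7]])

# Option 1: in-place scatter with a scalar fill value, so no src tensor is built.
out_scatter = torch.zeros(2, 8)
out_scatter.scatter_(dim=-1, index=index, value=1.0)

# Option 2: advanced indexing. The row indices have shape (2, 1) so they
# broadcast against index, which has shape (2, 3).
out_indexed = torch.zeros(2, 8)
rows = torch.arange(index.size(0)).unsqueeze(1)
out_indexed[rows, index] = 1.0

print(out_scatter)
print(torch.equal(out_scatter, out_indexed))
```

The advanced-indexing variant is probably also why value[range(2), index] = 1 fails in the question: the row index there has shape (2,), which cannot broadcast against an index of shape (2, 3), and the index tensor is float rather than int64.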