python, pytorch, tensor, advanced-indexing

Best way to convert a tensor from a condensed representation


I have a Tensor that is in a condensed format representing a sparse 3-D matrix. I need to convert it to the normal (dense) matrix that it actually represents. In my case, each row of any 2-D slice of the matrix can contain only one non-zero element, so as data I have, for each of these rows, the value and the column index where it appears. For example, the tensor

inp = torch.tensor([[ 1,  2],
                    [ 3,  4],
                    [-1,  0],
                    [45,  1]])

represents a 4x5 matrix A (the first dimension comes from the first dimension of the tensor; the second, 5, is known from metadata), where A[0][2] = 1, A[1][4] = 3, A[2][0] = -1, A[3][1] = 45.

This is just one 2-D slice of my matrix, and I have a variable number of these slices. I was able to handle a single 2-D slice as described above in the following way using sparse_coo_tensor:

>>> torch.sparse_coo_tensor(torch.stack([torch.arange(0, 4), inp.t()[1]]), inp.t()[0], [4,5]).to_dense()
tensor([[ 0,  0,  1,  0,  0],
        [ 0,  0,  0,  0,  3],
        [-1,  0,  0,  0,  0],
        [ 0, 45,  0,  0,  0]])

Is this the best way to accomplish this? Is there a simpler, more readable alternative? How do I extend this to a 3-D matrix without looping? For a 3-D matrix, you can imagine the input to be something like

inp_list = torch.stack([inp, inp, inp, inp])

and the desired output would be the above output stacked 4 times.
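
For concreteness, inp_list then has shape [4, 4, 2]; the last dimension holds the (value, column index) pairs:

>>> inp_list.shape
torch.Size([4, 4, 2])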

I feel like I should be able to do something if I create an index array correctly, but I cannot think of a way to do this without using some kind of looping.


Solution

  • OK, after a lot of experimenting with different types of indexing, I got this to work. It turns out the answer was in advanced indexing. Unfortunately, the PyTorch documentation doesn't go into the details of advanced indexing; here is a link to it in the NumPy documentation.

    For the problem described above, this command did the trick:

    >>> k_lst = torch.zeros([4,4,5])
    >>> k_lst[torch.arange(4).unsqueeze(1), torch.arange(4), inp_list[:,:,1]] = inp_list[:,:,0].float()
    >>> k_lst
    tensor([[[ 0.,  0.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  3.],
             [-1.,  0.,  0.,  0.,  0.],
             [ 0., 45.,  0.,  0.,  0.]],

            [[ 0.,  0.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  3.],
             [-1.,  0.,  0.,  0.,  0.],
             [ 0., 45.,  0.,  0.,  0.]],

            [[ 0.,  0.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  3.],
             [-1.,  0.,  0.,  0.,  0.],
             [ 0., 45.,  0.,  0.,  0.]],

            [[ 0.,  0.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0.,  3.],
             [-1.,  0.,  0.,  0.,  0.],
             [ 0., 45.,  0.,  0.,  0.]]])
    

    Which is exactly what I wanted.

    I learned quite a few things while searching for this, and I want to share them for anyone who stumbles on this question. So, why does this work? The answer lies in how broadcasting works. If you look at the shapes of the different index tensors involved, you'll see that they are (of necessity) broadcastable.

    >>> torch.arange(4).unsqueeze(1).shape, torch.arange(4).shape, inp_list[:,:,1].shape
    (torch.Size([4, 1]), torch.Size([4]), torch.Size([4, 4]))
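
    To make the broadcast concrete, you can materialize the three index tensors with torch.broadcast_tensors (just a sanity check on the same inp_list as above; the names b, r, c are only for illustration):

    >>> b, r, c = torch.broadcast_tensors(torch.arange(4).unsqueeze(1), torch.arange(4), inp_list[:,:,1])
    >>> b.shape, r.shape, c.shape
    (torch.Size([4, 4]), torch.Size([4, 4]), torch.Size([4, 4]))

    Element (i, j) of b, r and c taken together gives one (slice, row, column) triple, so k_lst[b, r, c] addresses exactly the 16 cells being assigned.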
    

    Clearly, to access an element of a 3-D tensor such as k_lst here, we need 3 indices, one for each dimension. If you give 3 tensors of the same shape to the [] operator, it produces a set of valid indices by matching up corresponding elements from the 3 tensors, as in the small example below.
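
    As a small standalone illustration (not the question's data), giving two same-shaped index tensors to [] picks out one element per matched position:

    >>> t = torch.arange(9).reshape(3, 3)
    >>> t[torch.tensor([0, 1, 2]), torch.tensor([2, 0, 1])]
    tensor([2, 3, 7])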

    If the 3 tensors have different but broadcastable shapes (as is the case here), the smaller tensors are effectively repeated along their missing rows/columns until all three end up with the same shape.

    Ultimately, in my case, if we trace how the different values get assigned, the indexing above is equivalent to doing

    k_lst[0,0,inp_list[0,0,1]] = inp_list[0,0,0].float()
    k_lst[0,1,inp_list[0,1,1]] = inp_list[0,1,0].float()
    k_lst[0,2,inp_list[0,2,1]] = inp_list[0,2,0].float()
    k_lst[0,3,inp_list[0,3,1]] = inp_list[0,3,0].float()
    k_lst[1,0,inp_list[1,0,1]] = inp_list[1,0,0].float()
    k_lst[1,1,inp_list[1,1,1]] = inp_list[1,1,0].float()
    .
    .
    .
    k_lst[3,3,inp_list[3,3,1]] = inp_list[3,3,0].float()
    

    This format reminds me of torch.Tensor.scatter(), but if it can be used to solve this problem, I haven't figured out how yet.
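
    For what it's worth, a scatter_-based version along the last dimension looks like it should give the same result; this is a sketch I haven't verified against the original setup, and out is just an illustrative name:

    >>> # Unverified sketch: scatter each value into dim 2 at its stored column index
    >>> out = torch.zeros(4, 4, 5).scatter_(2, inp_list[:,:,1].unsqueeze(-1), inp_list[:,:,0].float().unsqueeze(-1))

    If that is right, torch.equal(out, k_lst) should come back True.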