
PyTorch scatter max for sparse tensors?


I have the following PyTorch code

import torch
from torch_scatter import scatter_max

value_tensor = torch.sparse_coo_tensor(indices=query_indices.t(), values=values, size=(num_lines, img_size, img_size)).to(device=device)
value_tensor = value_tensor.to_dense()
# one pixel index per element of the flattened (num_lines, img_size, img_size) tensor
indices = torch.arange(0, img_size * img_size).repeat(num_lines).to(device=device)
line_tensor_flat = value_tensor.flatten()
img, _ = scatter_max(line_tensor_flat, indices, dim=0)
img = torch.reshape(img, (img_size, img_size))

Note the line value_tensor = value_tensor.to_dense(); unsurprisingly, it is slow.

However, I can't figure out how to get the same result while keeping the tensor sparse, because the code relies on reshape, which is not available on sparse tensors. I'm using scatter_max, but I'm open to anything that works.
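For reference, the dense computation above is a per-pixel max over the line dimension: flattening a (num_lines, img_size, img_size) tensor in row-major order and pairing it with arange(img_size * img_size).repeat(num_lines) sends pixel p of every line to output slot p. A minimal sketch on made-up toy data (the data and variable names are illustrative assumptions). It uses PyTorch's built-in Tensor.scatter_reduce with reduce="amax" in place of torch_scatter.scatter_max and checks the result against a plain amax(dim=0):

```python
import torch

# Hypothetical toy data: 2 lines on a 3x3 image; indices are given as rows
# (line, row, col), i.e. already in the (3, nnz) layout sparse_coo_tensor expects.
num_lines, img_size = 2, 3
query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],  # line
    [0, 1, 2, 0, 1, 2],  # row
    [0, 1, 0, 0, 1, 0],  # column
])
values = torch.tensor([1, 2, 3, 4, 5, 6])
value_tensor = torch.sparse_coo_tensor(
    query_indices, values, size=(num_lines, img_size, img_size)
)

# The question's approach, with the built-in scatter_reduce standing in for
# torch_scatter.scatter_max: pixel p of every line is sent to output slot p.
flat = value_tensor.to_dense().flatten()
pixel_idx = torch.arange(img_size * img_size).repeat(num_lines)
via_scatter = torch.zeros(img_size * img_size, dtype=flat.dtype).scatter_reduce(
    0, pixel_idx, flat, reduce="amax", include_self=False
).reshape(img_size, img_size)

# The same reduction written directly as a max over the line dimension.
dense_img = value_tensor.to_dense().amax(dim=0)
print(dense_img)
```

Both forms still materialize the dense tensor, so they are only useful as a correctness baseline for a sparse version.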


Solution

  • You can apply scatter_max directly to the sparse tensor's stored values, as long as the index tensor you pass to scatter_max is also restricted to those values, i.e. it has one entry per non-zero (stored) element and none for the zeros.

    Consider this example:

    import torch
    from torch_scatter import scatter_max

    query_indices = torch.tensor([
        [0, 0, 0, 1, 1, 1],
        [0, 1, 2, 0, 1, 2],
        [0, 1, 0, 0, 1, 0]
    ])
    
    values = torch.tensor([1, 2, 3, 4, 5, 6])
    num_lines = 2
    img_size = 3
    
    value_tensor = torch.sparse_coo_tensor(
        indices=query_indices,
        values=values,
        size=(num_lines, img_size, img_size)
    )
    
    # sparse_coo_tensor allows duplicate indices; coalesce() sums duplicates and sorts the indices,
    # and it is required here because .indices() and .values() raise an error on an uncoalesced tensor
    value_tensor = value_tensor.coalesce()
    
    

    Next, compute flat_indices: a 1-D tensor holding the flattened pixel index (row * img_size + col) of each stored value only. This is the same pixel numbering your arange produces, but without entries for the zeros.

    indices = value_tensor.indices()
    values = value_tensor.values()
    
    batch_indices = indices[0]        # "line" (in your terminology) indices; not needed, since the max is taken over all lines
    row_indices = indices[1]
    col_indices = indices[2]
    flat_indices = row_indices * img_size + col_indices
    
    
    

    Then pass values and flat_indices to scatter_max, setting dim_size so that pixels with no stored value still get a slot in the output:

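    # dim_size makes the output cover every pixel, even pixels with no stored value
    flattened_result, _ = scatter_max(
        values, flat_indices, dim=0, dim_size=img_size * img_size
    )
    
    # max over all lines for each pixel, reshaped back into an image
    per_line_max = flattened_result.reshape(img_size, img_size)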
    
    
    indices
    
    tensor([[0, 0, 0, 1, 1, 1],
            [0, 1, 2, 0, 1, 2],
            [0, 1, 0, 0, 1, 0]])
    
    
    values
    
    tensor([1, 2, 3, 4, 5, 6])
    
    
    flat_indices
    
    tensor([0, 4, 6, 0, 4, 6])
    
    per_line_max
    
    tensor([[4, 0, 0],
            [0, 5, 0],
            [6, 0, 0]])
    
    

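    For this example, the output matches what your dense code produces. One caveat: the two can differ when values are negative. In the dense version, every line with no entry at a pixel contributes an implicit 0 to that pixel's max, whereas here only the stored values take part (and pixels with no stored value at all come out as 0).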
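If you would rather not depend on torch_scatter, recent PyTorch releases can do the same scatter with the built-in Tensor.scatter_reduce (reduce="amax"). A minimal sketch of the sparse approach with that swap; sparse_max_over_lines is a hypothetical helper name. With include_self=False the zero-initialized output does not take part in the max, so each pixel gets the max of its stored values and pixels with no stored value stay 0, the same as the scatter_max output shown above:

```python
import torch


def sparse_max_over_lines(value_tensor: torch.Tensor, img_size: int) -> torch.Tensor:
    """Hypothetical helper: per-pixel max over the line dimension of a sparse
    (num_lines, img_size, img_size) COO tensor, using only the stored values."""
    value_tensor = value_tensor.coalesce()  # .indices()/.values() need a coalesced tensor
    indices = value_tensor.indices()
    values = value_tensor.values()

    # Flatten (row, col) into a pixel index; the line index (indices[0]) is
    # dropped because we reduce over lines.
    flat_indices = indices[1] * img_size + indices[2]

    out = torch.zeros(img_size * img_size, dtype=values.dtype, device=values.device)
    # include_self=False: the initial zeros are not part of the max, so a pixel's
    # result is the max of its stored values; untouched pixels remain 0.
    out = out.scatter_reduce(0, flat_indices, values, reduce="amax", include_self=False)
    return out.reshape(img_size, img_size)


# Usage on the toy example from the answer
query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 0, 1, 0],
])
values = torch.tensor([1, 2, 3, 4, 5, 6])
value_tensor = torch.sparse_coo_tensor(query_indices, values, size=(2, 3, 3))
img = sparse_max_over_lines(value_tensor, img_size=3)
print(img)
```

The work is proportional to the number of stored values plus the size of the output image; the (num_lines, img_size, img_size) dense tensor is never built.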
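When values can be negative, a max over the stored values alone does not reproduce the dense code, because each line that stores nothing at a pixel contributes an implicit 0 there. A minimal sketch, assuming you need the exact dense result (dense_equivalent_max is a hypothetical helper name): it counts with torch.bincount how many lines store each pixel and folds a 0 into the max wherever that count is below num_lines.

```python
import torch


def dense_equivalent_max(value_tensor, num_lines, img_size):
    """Hypothetical helper: per-pixel max over lines that matches
    value_tensor.to_dense().amax(dim=0), including implicit zeros."""
    value_tensor = value_tensor.coalesce()  # unique (line, row, col) entries
    indices = value_tensor.indices()
    values = value_tensor.values()
    flat_indices = indices[1] * img_size + indices[2]
    n_pixels = img_size * img_size

    # Max over the stored values only; pixels with no stored value stay 0.
    out = torch.zeros(n_pixels, dtype=values.dtype, device=values.device)
    out = out.scatter_reduce(0, flat_indices, values, reduce="amax", include_self=False)

    # Number of lines that store each pixel (entries are unique after coalesce).
    counts = torch.bincount(flat_indices, minlength=n_pixels)
    # If some line has no entry at a pixel, the dense tensor holds a 0 there,
    # so that pixel's max is at least 0.
    out = torch.where(counts < num_lines, out.clamp(min=0), out)
    return out.reshape(img_size, img_size)


# Toy data with negative values: pixel (2, 0) is stored only in line 0.
query_indices = torch.tensor([
    [0, 0, 0, 1, 1],
    [0, 1, 2, 0, 1],
    [0, 1, 0, 0, 1],
])
values = torch.tensor([-1, -2, -3, -4, 5])
value_tensor = torch.sparse_coo_tensor(query_indices, values, size=(2, 3, 3))

result = dense_equivalent_max(value_tensor, num_lines=2, img_size=3)
dense_ref = value_tensor.to_dense().amax(dim=0)  # dense baseline for comparison
print(result)
```

The extra bincount is one more pass over the stored indices, so the cost stays proportional to the number of stored values plus img_size ** 2 rather than num_lines * img_size ** 2.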