I have the following PyTorch code:
value_tensor = torch.sparse_coo_tensor(indices=query_indices.t(), values=values, size=(num_lines, img_size, img_size)).to(device=device)
value_tensor = value_tensor.to_dense()
indices = torch.arange(0, img_size * img_size).repeat(len(lines)).to(device=device)
line_tensor_flat = value_tensor.flatten()
img, _ = scatter_max(line_tensor_flat, indices, dim=0)
img = torch.reshape(img, (img_size, img_size))
Note the line value_tensor = value_tensor.to_dense(); unsurprisingly, it is slow.
However, I cannot figure out how to obtain the same result with a sparse tensor. The function in question calls reshape, which is not available on sparse tensors. I'm using Scatter Max, but I'm open to anything that works.
You should be able to use scatter_max directly on the sparse tensor, as long as the indices you pass to scatter_max are also sparse (i.e., only those of the non-zero entries).
Consider this example:
import torch
from torch_scatter import scatter_max

query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 0, 1, 0]
])
values = torch.tensor([1, 2, 3, 4, 5, 6])
num_lines = 2
img_size = 3
value_tensor = torch.sparse_coo_tensor(
    indices=query_indices,
    values=values,
    size=(num_lines, img_size, img_size)
)
# sparse_coo_tensor does not guarantee unique indices; coalesce() merges (sums)
# duplicates and is required before calling .indices() and .values()
value_tensor = value_tensor.coalesce()
Then compute flat_indices, a 1-d tensor holding the flattened pixel index of each non-zero entry (the 2-d row/column indices are converted to 1-d indices, similar to your arange):
indices = value_tensor.indices()
values = value_tensor.values()
batch_indices = indices[0] # "line" (in your terminology) indices
row_indices = indices[1]
col_indices = indices[2]
flat_indices = row_indices * img_size + col_indices
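The row-major conversion above can be checked in isolation. The sketch below assumes the same toy row and column indices as the example and compares them against an arange laid out as an image; pixel_ids is a name introduced here purely for illustration.

```python
import torch

img_size = 3
rows = torch.tensor([0, 1, 2, 0, 1, 2])
cols = torch.tensor([0, 1, 0, 0, 1, 0])

# Row-major flattening: pixel (r, c) lands at position r * img_size + c.
flat_indices = rows * img_size + cols

# The same positions, read from an arange reshaped into an image grid.
pixel_ids = torch.arange(img_size * img_size).reshape(img_size, img_size)
assert torch.equal(flat_indices, pixel_ids[rows, cols])

print(flat_indices)  # tensor([0, 4, 6, 0, 4, 6])
```

The line index is deliberately left out of flat_indices, so entries from different lines that hit the same pixel share one slot and get reduced together.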
You can then pass flat_indices to scatter_max:
flattened_result, _ = scatter_max(
    values, flat_indices, dim=0, dim_size=img_size * img_size
)
per_line_max = flattened_result.reshape(img_size, img_size)
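Since the question is open to alternatives, here is a minimal sketch of the same reduction without torch_scatter, using PyTorch's built-in Tensor.scatter_reduce with reduce="amax" (available in recent PyTorch releases). This swaps in a different function from the one used above; sparse_line_max is a hypothetical helper name, and the toy values are floats here only to keep the sketch simple.

```python
import torch


def sparse_line_max(value_tensor: torch.Tensor, img_size: int) -> torch.Tensor:
    """Per-pixel max over lines of a sparse (num_lines, img_size, img_size) tensor.

    Hypothetical helper, not part of PyTorch or torch_scatter.
    """
    value_tensor = value_tensor.coalesce()  # .indices() needs a coalesced tensor
    _, rows, cols = value_tensor.indices()  # drop the line index, keep the pixel
    flat_indices = rows * img_size + cols

    out = torch.zeros(
        img_size * img_size, dtype=value_tensor.dtype, device=value_tensor.device
    )
    # include_self=True lets the zero initialisation take part in the max,
    # so pixels that are never hit stay at 0.
    out = out.scatter_reduce(
        0, flat_indices, value_tensor.values(), reduce="amax", include_self=True
    )
    return out.reshape(img_size, img_size)


query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 0, 1, 0],
])
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
value_tensor = torch.sparse_coo_tensor(query_indices, values, size=(2, 3, 3))

result = sparse_line_max(value_tensor, img_size=3)
print(result)
```

Because the output starts at zero and include_self=True, this variant returns max(0, values) per pixel. That only matters if the stored values can be negative.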
indices
tensor([[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2],
[0, 1, 0, 0, 1, 0]])
values
tensor([1, 2, 3, 4, 5, 6])
flat_indices
tensor([0, 4, 6, 0, 4, 6])
per_line_max
tensor([[4, 0, 0],
[0, 5, 0],
[6, 0, 0]])
The output I get is the same as what I get from your code.
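To check the equivalence on something larger than the toy example, a randomized comparison like the one below may help. It is a sketch under two assumptions: the stored values are non-negative, and the dense computation from the question is equivalent to a maximum over the line dimension, written here as to_dense().amax(dim=0). To keep the check free of extra dependencies, the sparse side uses PyTorch's built-in scatter_reduce as a stand-in for scatter_max.

```python
import torch

torch.manual_seed(0)
num_lines, img_size, nnz = 4, 5, 30

# Random (line, row, col) triples; duplicates are possible and are summed by coalesce().
idx = torch.stack([
    torch.randint(0, num_lines, (nnz,)),
    torch.randint(0, img_size, (nnz,)),
    torch.randint(0, img_size, (nnz,)),
])
vals = torch.rand(nnz)  # non-negative by construction
sp = torch.sparse_coo_tensor(idx, vals, (num_lines, img_size, img_size)).coalesce()

# Dense reference: max over the line dimension, implicit zeros included.
dense_ref = sp.to_dense().amax(dim=0)

# Sparse path: reduce only the stored entries into their flattened pixel slots.
flat = sp.indices()[1] * img_size + sp.indices()[2]
sparse_out = (
    torch.zeros(img_size * img_size)
    .scatter_reduce(0, flat, sp.values(), reduce="amax")
    .reshape(img_size, img_size)
)

assert torch.equal(dense_ref, sparse_out)
```

With negative values the two paths can disagree, because the dense tensor contributes an implicit 0 for every line that has no entry at a pixel, whereas the sparse path only sees stored entries.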