Is there a way to get the histograms of torch tensors in batches?
For example:

# x is a tensor of shape (64, 224, 224)
x = batch_histogram(x, bins=256, min=0, max=255)
# x now has shape (64, 256)
As suggested in PyTorch issue #99719, you can do this with torch.Tensor.scatter_add_, which is more memory-efficient than torch.nn.functional.one_hot.
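For comparison, here is a sketch of the one_hot approach mentioned in the same issue thread (the function name batch_histogram_one_hot is mine, for illustration). It materializes a D_1 x ... x D_n x num_classes intermediate tensor, which is why scatter_add_ is the more memory-efficient choice on large inputs:

```python
import torch

def batch_histogram_one_hot(data_tensor, num_classes=-1):
    # one_hot expands each value into a num_classes-length indicator vector;
    # summing over the second-to-last dim counts occurrences per class
    # along the last data dimension
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)

x = torch.tensor([[0, 1, 1, 3],
                  [2, 2, 2, 0]])
hist = batch_histogram_one_hot(x, num_classes=4)
print(hist)
# tensor([[1, 2, 0, 1],
#         [1, 0, 3, 0]])
```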
Similar to @user118967's answer:
import torch

# https://github.com/pytorch/pytorch/issues/99719#issuecomment-1664135524
def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms in batches (as opposed to torch.histc and
    torch.histogram, which flatten their input).
    Arguments:
        data_tensor: a D_1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in the data.
            If not provided, tensor.max() + 1 is used (an error is thrown
            if the tensor is empty).
    Returns:
        A D_1 x ... x D_{n-1} x num_classes torch.LongTensor 'result'
        containing histograms of the last dimension D_n of tensor, that is,
        result[d_1, ..., d_{n-1}, c] = number of times c appears
        in tensor[d_1, ..., d_{n-1}].
    """
    maxd = data_tensor.max()
    # convert the 0-dim tensor maxd to a Python int so it can be used
    # as a shape element below
    nc = int(maxd) + 1 if num_classes <= 0 else num_classes
    hist = torch.zeros((*data_tensor.shape[:-1], nc),
                       dtype=data_tensor.dtype, device=data_tensor.device)
    # a stride-0 "ones" tensor: scatter_add_ only reads src, so no copy is needed
    ones = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand(data_tensor.shape)
    # rescale values into [0, nc) bin indices, then count them via scatter_add_
    hist.scatter_add_(-1, ((data_tensor * nc) // (maxd + 1)).long(), ones)
    return hist
You can find the test cases in Google Colab here.
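For the shapes in the question, here is a minimal self-contained sketch of the same scatter_add_ trick (assuming the pixel values are already integers in [0, 255]); note that the histogram runs over the last dimension, so H and W are flattened first:

```python
import torch

x = (torch.rand(64, 224, 224) * 256).long()  # simulated batch of 8-bit images
flat = x.flatten(start_dim=1)                # (64, 224*224): bins run over the last dim
hist = torch.zeros(64, 256, dtype=torch.long)
# each value in flat picks its bin index; adding 1 there counts the occurrence
hist.scatter_add_(-1, flat, torch.ones_like(flat))
print(hist.shape)  # torch.Size([64, 256])
```

Each row of hist sums to 224 * 224, one count per pixel.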