I have a batch of images and a batch of (x, y) indices for each image. The indices are different for each image, so I can't use simple indexing. What is the best or fastest way to get another batch containing the colors of the selected pixels of each image?
import torch
n_images = 4
width = 100
height = 100
channels = 3
n_samples = 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)  # (x, y) pairs; scaling both by width works because width == height
# what I would like to write (this does not work):
# result = images[indices]
# with result.shape == (n_images, n_samples, 3)
# workaround I found, though I would rather call a general torch function:
xs = indices.reshape((-1, 2))[:, 0]  # x (column) of every sample, flattened
ys = indices.reshape((-1, 2))[:, 1]  # y (row) of every sample, flattened
ix = torch.arange(n_images, dtype=torch.int32)
ix = ix[..., None].expand((-1, n_samples)).flatten()  # image index for every sample
result = images[ix, ys, xs].reshape((n_images, n_samples, channels))
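For reference, the selection being asked for is result[b, s] = images[b, y, x] with (x, y) = indices[b, s]. A plain Python loop spells this out; it is slow, but handy for checking any vectorized version against. gather_pixels_loop is a hypothetical helper name, and the indices are cast to int64 ("long") because that is the index dtype every PyTorch version accepts (older releases reject int32 index tensors).

```python
import torch

# Same setup as in the question, but with int64 indices (assumption, see above).
n_images, height, width, channels, n_samples = 4, 100, 100, 3, 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).long()


def gather_pixels_loop(images, indices):
    """Reference version: result[b, s] = images[b, y, x] where (x, y) = indices[b, s]."""
    n_images, n_samples, _ = indices.shape
    result = images.new_empty((n_images, n_samples, images.shape[-1]))
    for b in range(n_images):
        for s in range(n_samples):
            x, y = indices[b, s].tolist()
            result[b, s] = images[b, y, x]  # copies all channels of one pixel
    return result


reference = gather_pixels_loop(images, indices)
print(reference.shape)  # torch.Size([4, 30, 3])
```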
You can use your indices tensor directly; you just need another tensor for the batch indexing:
import torch
n_images = 4
width = 100
height = 100
channels = 3
n_samples = 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)
batch_indices = torch.arange(n_images).view(-1, 1).expand(-1, n_samples)
result = images[batch_indices, indices[..., 1], indices[..., 0]]
This follows your convention of images[ix, ys, xs], where the ys index the height dimension of the tensor and the xs index the width.
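If you would rather call a single general torch function, torch.gather can do the same selection once the height and width dimensions are flattened into one, so pixel (y, x) becomes row y * width + x. This is a swapped-in alternative, not the approach above, and a sketch under these assumptions: sample_pixels_gather is a hypothetical helper name; x and y are drawn with separate upper bounds (the question scales both by width, which only works because width == height); and the indices are cast to int64 because torch.gather requires an int64 index.

```python
import torch

n_images, height, width, channels, n_samples = 4, 100, 100, 3, 30
images = torch.rand((n_images, height, width, channels))
# x in [0, width), y in [0, height), drawn separately so non-square images also work
xs = torch.randint(0, width, (n_images, n_samples))
ys = torch.randint(0, height, (n_images, n_samples))
indices = torch.stack((xs, ys), dim=-1)  # (n_images, n_samples, 2), last dim = (x, y)


def sample_pixels_gather(images, indices):
    """Pick images[b, y, x] for every (x, y) in indices[b] using torch.gather."""
    n_images, height, width, channels = images.shape
    # Flatten the two spatial dims: pixel (y, x) becomes row y * width + x.
    flat = images.reshape(n_images, height * width, channels)
    linear = indices[..., 1].long() * width + indices[..., 0].long()  # (n_images, n_samples)
    # gather needs an index with as many dims as `flat`,
    # so repeat each linear index across the channel dim.
    index = linear.unsqueeze(-1).expand(-1, -1, channels)  # (n_images, n_samples, channels)
    return torch.gather(flat, 1, index)


result_gather = sample_pixels_gather(images, indices)

# Advanced-indexing version from the answer: a (n_images, 1) batch index
# broadcasts against the (n_images, n_samples) index tensors, so .expand is optional.
batch_indices = torch.arange(n_images).view(-1, 1)
result_index = images[batch_indices, indices[..., 1], indices[..., 0]]
print(torch.equal(result_gather, result_index))  # True
```

The advanced-indexing version is usually easier to read; gather is mainly convenient when you already work with flattened pixel indices. Which one is faster likely depends on the device and tensor sizes, so it is worth measuring both.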