I have clustered the pixels of an image into clusters of different sizes and shapes. I want to max pool each cluster as fast as possible because the max pooling happens in one layer of my CNN.
To clarify: the input is a batch of images with shape [batch_size, height of image, width of image, number of channels]. I clustered each image before training my CNN, so for each image I have an ndarray of labels with shape [height of image, width of image].
How can I max pool over all pixels of an image that share the same label, for every label? I know how to do it with a for loop, but that is painfully slow. I am looking for a fast solution that ideally max pools over every cluster of each image in less than a second.
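For reference, here is a minimal sketch of the kind of loop-based version meant here, for a single image in NumPy. The function name and the toy data are illustrative only, not my actual code.

```python
import numpy as np

def cluster_max_pool_loop(img, labels):
    """Naive per-cluster max pooling for one image.

    img:    float array of shape [H, W, C]
    labels: int array of shape [H, W], the cluster id of each pixel
    Returns a dict mapping cluster id -> per-channel maximum of shape [C].
    """
    pooled = {}
    for label in np.unique(labels):
        mask = labels == label                       # [H, W] boolean mask of this cluster
        pooled[int(label)] = img[mask].max(axis=0)   # img[mask] has shape [n_pixels, C]
    return pooled

# Toy 2x2 image with 3 channels and two clusters (illustrative data)
img = np.arange(2 * 2 * 3, dtype=np.float32).reshape(2, 2, 3)
labels = np.array([[0, 1],
                   [1, 0]])
pooled = cluster_max_pool_loop(img, labels)
```

The cost grows with the number of clusters times the number of images, since each iteration builds a full-size mask.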
For the implementation, I use Python 3.7 and PyTorch.
I figured it out: torch_scatter.scatter_max(img, cluster_labels) returns the maximum element of each cluster (together with its index), which removes the for loop from my code.
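To make the shapes concrete, here is a minimal sketch of the same scatter idea for the [batch_size, height, width, channels] layout. It swaps in PyTorch's built-in `Tensor.scatter_reduce` (assuming PyTorch 1.12 or newer) instead of torch_scatter, so it runs without the extra package. The function name and the `num_clusters` argument are assumptions for illustration, and the labels are assumed to be integers in `[0, num_clusters)`.

```python
import torch

def cluster_max_pool(img, labels, num_clusters):
    """Max pool every cluster of every image without a Python loop.

    img:          float tensor [B, H, W, C]
    labels:       long tensor  [B, H, W], values in [0, num_clusters)
    num_clusters: upper bound on clusters per image (assumed to be known)
    Returns a tensor [B, num_clusters, C]; clusters without pixels stay at -inf.
    """
    B, H, W, C = img.shape
    src = img.reshape(B, H * W, C)
    # scatter_reduce expects an index with the same shape as src,
    # so repeat each pixel's label across the channel dimension
    index = labels.reshape(B, H * W, 1).expand(B, H * W, C)
    out = torch.full((B, num_clusters, C), float("-inf"),
                     dtype=img.dtype, device=img.device)
    # Running maximum of all pixels that share a label, per image and channel
    return out.scatter_reduce(1, index, src, reduce="amax")

# Toy batch: one 2x2 image, 3 channels, two clusters (illustrative data)
img = torch.arange(12, dtype=torch.float32).reshape(1, 2, 2, 3)
labels = torch.tensor([[[0, 1],
                        [1, 0]]])
pooled = cluster_max_pool(img, labels, num_clusters=2)
```

Labels stored as an ndarray can be converted with `torch.from_numpy(labels).long()`. With torch_scatter installed, `scatter_max(src, index, dim=1)` should play the same role on the flattened tensors; note that it returns a `(max, argmax)` tuple rather than a single tensor.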