
JAX Point Cloud Processing: Slow index_points_3d operation causing extreme XLA fusion loops in backpropagation


I'm using JAX to implement point cloud processing. However, training becomes extremely slow because of my implementation of the following index_points_3d operation, which gathers per-point features using a 3D index array.

Here's my current implementation:

import jax
import jax.numpy as jnp

@jax.jit
def index_points_3d(features, indices):
    """
    Args:
        features: shape (B, N, C)
        indices: shape (B, npoint, nsample)
    
    Returns:
        shape (B, npoint, nsample, C)
    """
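    # (B, N, C) -> (B, N, 1, C) and (B, npoint, nsample) -> (B, npoint, nsample, 1),
    # so both arrays are 4-D; take_along_axis gathers along axis 1 and broadcasts
    # the size-1 axes, giving (B, npoint, nsample, C).
    features_expanded = features[..., None, :]
    idx_expanded = indices[..., None]
    return jnp.take_along_axis(features_expanded, idx_expanded, axis=1)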
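For reference, here is a small shape check of the implementation above on hypothetical sizes (B=2, N=10, C=3, npoint=4, nsample=5). It also compares the result against an equivalent advanced-indexing form, features[batch, indices]; the helper name index_points_3d_fancy is made up for this sketch, and jit is omitted since it doesn't change the result.

```python
import jax
import jax.numpy as jnp

def index_points_3d(features, indices):
    # Same logic as the implementation above (without @jax.jit).
    return jnp.take_along_axis(features[..., None, :], indices[..., None], axis=1)

def index_points_3d_fancy(features, indices):
    # Hypothetical equivalent: a (B, 1, 1) batch index broadcasts against
    # indices (B, S, K); the trailing C axis is carried along, giving (B, S, K, C).
    batch = jnp.arange(features.shape[0])[:, None, None]
    return features[batch, indices]

B, N, C, S, K = 2, 10, 3, 4, 5
features = jax.random.normal(jax.random.PRNGKey(0), (B, N, C))
indices = jax.random.randint(jax.random.PRNGKey(1), (B, S, K), 0, N)

out = index_points_3d(features, indices)
print(out.shape)  # (2, 4, 5, 3)
print(jnp.allclose(out, index_points_3d_fancy(features, indices)))  # True
```

Both forms compute out[b, s, k, :] = features[b, indices[b, s, k], :], so they lower to the same kind of gather.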

When I profiled training, I found that this operation triggers a very large number of repeated loop_dynamic_update_slice_fusion, loop_add_fusion, input_reduce_fusion, and loop_select_fusion ops in the backward pass.


The forward pass is not the problem: training runs fast when I stop the gradient through the output features.
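A plausible explanation for why only the backward pass is slow: in JAX, the transpose (VJP) of a gather is a scatter-add, which accumulates cotangents into a zero-initialized array of the features' shape. Depending on the backend and shapes, XLA may lower a scatter into loops, which would be consistent with the dynamic_update_slice/add/select fusions in the profile, though that part is an assumption. The sketch below just inspects the jaxprs to show the gather in the forward pass and the scatter-add in the gradient; the toy loss and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

def index_points_3d(features, indices):
    return jnp.take_along_axis(features[..., None, :], indices[..., None], axis=1)

def loss(features, indices):
    # Toy scalar loss so jax.grad can differentiate w.r.t. features (argnums=0).
    return jnp.sum(index_points_3d(features, indices) ** 2)

features = jnp.ones((2, 10, 3))
indices = jnp.zeros((2, 4, 5), dtype=jnp.int32)

fwd_text = str(jax.make_jaxpr(index_points_3d)(features, indices))
bwd_text = str(jax.make_jaxpr(jax.grad(loss))(features, indices))

print("gather" in fwd_text)       # forward pass contains a gather
print("scatter-add" in bwd_text)  # gradient contains a scatter-add
```

Stopping the gradient removes the scatter-add entirely, which matches the observation that training is fast in that case.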

I've tried other implementations, such as using vmap over the batch dimension, but none of them improved performance.
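One likely reason vmap doesn't help: batching a per-example gather just produces a single batched gather, so the backward pass is still a scatter-add. Here is a minimal sketch of a vmap formulation (the names gather_one and index_points_3d_vmap are hypothetical), checked against the take_along_axis version.

```python
import jax
import jax.numpy as jnp

def gather_one(f, i):
    # f: (N, C), i: (S, K) -> (S, K, C) via integer-array indexing
    return f[i]

# Map over the leading batch axis B of both features and indices.
index_points_3d_vmap = jax.vmap(gather_one)

B, N, C, S, K = 2, 10, 3, 4, 5
features = jax.random.normal(jax.random.PRNGKey(0), (B, N, C))
indices = jax.random.randint(jax.random.PRNGKey(1), (B, S, K), 0, N)

out_vmap = index_points_3d_vmap(features, indices)
out_take = jnp.take_along_axis(features[..., None, :], indices[..., None], axis=1)
print(jnp.allclose(out_vmap, out_take))  # True

# The batched program still contains a gather, so its gradient is still a scatter-add.
jaxpr_text = str(jax.make_jaxpr(index_points_3d_vmap)(features, indices))
print("gather" in jaxpr_text)
```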

I'm not deeply familiar with JAX's low-level operations, so I'm unsure whether this is a fundamental limitation of JAX/XLA or whether there's a more efficient approach. Any help or guidance on optimizing this operation would be greatly appreciated!


Solution

  • Thanks to jakevdp's comment, I got a significant speedup by replacing the gather with a one-hot matrix multiplication. Here is the updated code:

    @jax.jit
    def index_points_3d(features, indices):
        """
        Args:
            features: shape (B, N, C)
            indices: shape (B, npoint, nsample)
        
        Returns:
            shape (B, npoint, nsample, C)
        """
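        N = features.shape[1]
        # (B, npoint, nsample) integer indices -> (B, npoint, nsample, N) one-hot weights
        one_hot = jax.nn.one_hot(indices, num_classes=N, dtype=features.dtype)
        # Contract over N: out[b, s, k, :] = features[b, indices[b, s, k], :].
        # The gradient w.r.t. features is then another matmul instead of a scatter-add.
        return jnp.einsum('bskn,bnc->bskc', one_hot, features)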
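To sanity-check the swap, the sketch below compares the one-hot version against the gather version on hypothetical sizes, both in the forward output and in the gradient w.r.t. features, and confirms that the one-hot gradient is built from dot_general rather than scatter-add. One assumption worth flagging: on accelerators the default matmul precision can be below full float32 (e.g. bf16 passes on TPU, TF32 on some GPUs), so the sketch passes precision=jax.lax.Precision.HIGHEST to jnp.einsum; whether that is needed in practice depends on your hardware and tolerance.

```python
import jax
import jax.numpy as jnp

def index_points_3d_gather(features, indices):
    return jnp.take_along_axis(features[..., None, :], indices[..., None], axis=1)

def index_points_3d_onehot(features, indices):
    n = features.shape[1]
    one_hot = jax.nn.one_hot(indices, num_classes=n, dtype=features.dtype)  # (B, S, K, N)
    # Assumption: HIGHEST precision so the one-hot "copy" is exact on accelerators too.
    return jnp.einsum('bskn,bnc->bskc', one_hot, features,
                      precision=jax.lax.Precision.HIGHEST)

def grad_wrt_features(fn, features, indices):
    # Gradient of a toy scalar loss w.r.t. features only.
    return jax.grad(lambda f: jnp.sum(fn(f, indices) ** 2))(features)

B, N, C, S, K = 2, 10, 3, 4, 5
features = jax.random.normal(jax.random.PRNGKey(0), (B, N, C))
indices = jax.random.randint(jax.random.PRNGKey(1), (B, S, K), 0, N)

out_gather = index_points_3d_gather(features, indices)
out_onehot = index_points_3d_onehot(features, indices)
g_gather = grad_wrt_features(index_points_3d_gather, features, indices)
g_onehot = grad_wrt_features(index_points_3d_onehot, features, indices)
print(jnp.allclose(out_gather, out_onehot))             # True
print(jnp.allclose(g_gather, g_onehot, atol=1e-5))      # True

bwd_text = str(jax.make_jaxpr(
    lambda f: grad_wrt_features(index_points_3d_onehot, f, indices))(features))
print("scatter-add" in bwd_text, "dot_general" in bwd_text)
```

The trade-off is memory: the one-hot tensor has B * npoint * nsample * N entries of features.dtype, so for large N it can get big even though the matmul itself runs efficiently.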