Tags: python, jax

Jax vmap limit memory


I'm wondering if there is a good way to limit the memory usage of JAX's vmap. Equivalently, is there a way to vmap over batches at a time, if that makes sense?

In my specific use case, I have a set of images and I'd like to calculate the affinity between each pair of images, so roughly order((num_imgs)^2 * (img_shape)) bytes of memory used all at once if I'm understanding vmap correctly. That gets huge in my real example, where I have 10,000 images of size 100x100.

A basic example is:

from itertools import combinations

import jax
import jax.numpy as jnp


def affinity_matrix_ex(n_arrays=10, img_size=5, key=jax.random.PRNGKey(0), gamma=jnp.array([0.5])):
    arr_of_imgs = jax.random.normal(key, (n_arrays, img_size, img_size))
    arr_of_indices = jnp.arange(n_arrays)
    inds_1, inds_2 = zip(*combinations(arr_of_indices, 2))
    v_cPA = jax.vmap(calcPairAffinity2, (0, 0, None, None), 0)
    affinities = v_cPA(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma)
    print()
    print(jax.make_jaxpr(v_cPA)(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma))
    
    affinities = affinities.reshape(-1)
    
    arr = jnp.zeros((n_arrays, n_arrays), dtype=jnp.float16)
    arr = arr.at[jnp.triu_indices(arr.shape[0], k=1)].set(affinities)
    arr = arr + arr.T
    arr = arr + jnp.identity(n_arrays, dtype=jnp.float16)
    
    return arr


def calcPairAffinity2(ind1, ind2, imgs, gamma):
    # Returns a jnp array containing one float; jnp.sum adds all elements together.
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    normed_diff = diff / image1.size
    val = jnp.exp(-gamma * normed_diff)
    val = val.astype(jnp.float16)
    return val

I suppose I could just feed vmap X pairs at a time, loop over n_chunks = n_pairs / X, and append each group's results to a list, but that doesn't seem ideal. My understanding is that vmap doesn't work with generators, so I'm not sure whether that would be an alternative way around the issue.


Solution

  • Edit, Aug 13 2024

    As of JAX version 0.4.31, what you're asking for is possible using the batch_size argument of lax.map. For an iterable of size N, this will perform a scan with N // batch_size steps, and within each step will vmap the function over the batch. lax.map has less flexible semantics than jax.vmap, but for the simplest cases they look relatively similar. Here's an example using your calcPairAffinity function:


    import jax
    import jax.numpy as jnp
    
    def calcPairAffinity(ind1, ind2, imgs, gamma=0.5):
        image1, image2 = imgs[ind1], imgs[ind2]
        diff = jnp.sum(jnp.abs(image1 - image2))  
        normed_diff = diff / image1.size
        val = jnp.exp(-gamma*normed_diff)
        val = val.astype(jnp.float16)
        return val
    
    imgs = jax.random.normal(jax.random.key(0), (100, 5, 5))
    inds = jnp.arange(imgs.shape[0])
    inds1, inds2 = map(jnp.ravel, jnp.meshgrid(inds, inds))
    
    def f(inds):
        return calcPairAffinity(*inds, imgs, 0.5)
    
    result_vmap = jax.vmap(f)((inds1, inds2))
    result_batched = jax.lax.map(f, (inds1, inds2), batch_size=1000)
    assert jnp.allclose(result_vmap, result_batched)
    

    Original answer

    This is a frequent request, but unfortunately there's not yet (as of JAX version 0.4.20) any built-in utility to do chunked/batched vmap (xmap does have some functionality along these lines, but is experimental/incomplete and I wouldn't recommend relying on it).

    Adding chunking to vmap is tracked in https://github.com/google/jax/issues/11319, and there's some code there that does a limited version of what you have in mind. Hopefully something like what you describe will be possible with JAX's built-in vmap soon. In the meantime, you might think about applying vmap to chunks manually in the way you describe in your question.
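
    For reference, a rough sketch of that manual approach (not a built-in JAX utility) might look like the following, reusing f, inds1, and inds2 from the example above; the chunk size and the helper name chunked_vmap are arbitrary placeholders, and the chunk size should be tuned to your memory budget:

    import jax
    import jax.numpy as jnp
    
    def chunked_vmap(f, inds1, inds2, chunk_size):
        # Apply vmap to one chunk of index pairs at a time, then
        # concatenate the per-chunk results into a single array.
        vf = jax.vmap(f)
        results = []
        for start in range(0, len(inds1), chunk_size):
            chunk = (inds1[start:start + chunk_size], inds2[start:start + chunk_size])
            results.append(vf(chunk))
        return jnp.concatenate(results)
    
    result_chunked = chunked_vmap(f, inds1, inds2, chunk_size=1000)

    Only one chunk's worth of intermediates is materialized at a time, at the cost of a Python-level loop over chunks.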