Is there a good way to limit the memory usage of JAX's vmap? Equivalently, is there a way to vmap in batches at a time, if that makes sense?
In my specific use case, I have a set of images and I'd like to calculate the affinity between each pair of images, so roughly O(num_imgs² · img_size²) bytes of memory are used all at once if I'm understanding vmap correctly (which gets huge, since in my real example I have 10,000 100x100 images).
A basic example is:

```python
from itertools import combinations

import jax
import jax.numpy as jnp

def calcPairAffinity2(ind1, ind2, imgs, gamma):
    # Returns a jnp array of 1 float; jnp.sum adds all elements together
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    normed_diff = diff / image1.size
    val = jnp.exp(-gamma * normed_diff)
    val = val.astype(jnp.float16)
    return val

def affinity_matrix_ex(n_arrays=10, img_size=5, key=jax.random.PRNGKey(0), gamma=jnp.array([0.5])):
    arr_of_imgs = jax.random.normal(key, (n_arrays, img_size, img_size))
    arr_of_indices = jnp.arange(n_arrays)
    inds_1, inds_2 = zip(*combinations(arr_of_indices, 2))
    v_cPA = jax.vmap(calcPairAffinity2, (0, 0, None, None), 0)
    affinities = v_cPA(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma)
    print()
    print(jax.make_jaxpr(v_cPA)(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma))
    affinities = affinities.reshape(-1)
    arr = jnp.zeros((n_arrays, n_arrays), dtype=jnp.float16)
    arr = arr.at[jnp.triu_indices(arr.shape[0], k=1)].set(affinities)
    arr = arr + arr.T
    arr = arr + jnp.identity(n_arrays, dtype=jnp.float16)
    return arr
```
I suppose I could just feed vmap X pairs at a time and loop through n_chunks = n_arrays / X, appending each chunk's results to a list, but that doesn't seem ideal. My understanding is that vmap does not work with generators, so I'm not sure whether that would be an alternative way around the issue.
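To be concrete, the chunked loop I have in mind would look something like the sketch below (the `chunk_size` value and helper names are just placeholders I made up):

```python
import jax
import jax.numpy as jnp

def calc_pair_affinity(ind1, ind2, imgs, gamma):
    # Same per-pair affinity as above, without the float16 cast for simplicity.
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    return jnp.exp(-gamma * diff / image1.size)

v_cpa = jax.vmap(calc_pair_affinity, (0, 0, None, None), 0)

def chunked_affinities(inds_1, inds_2, imgs, gamma, chunk_size=1000):
    # Feed vmap chunk_size index pairs at a time, so only one chunk's
    # worth of intermediates is alive at once; concatenate at the end.
    results = []
    for start in range(0, inds_1.shape[0], chunk_size):
        chunk1 = inds_1[start:start + chunk_size]
        chunk2 = inds_2[start:start + chunk_size]
        results.append(v_cpa(chunk1, chunk2, imgs, gamma))
    return jnp.concatenate(results)
```

This keeps peak memory proportional to `chunk_size` rather than the total number of pairs, at the cost of a Python-level loop.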
Edit, Aug 13 2024
As of JAX version 0.4.31, what you're asking for is possible using the `batch_size` argument of `lax.map`. For an iterable of size `N`, this will perform a scan with `N // batch_size` steps, and within each step will `vmap` the function over the batch. `lax.map` has less flexible semantics than `jax.vmap`, but for the simplest cases they look relatively similar. Here's an example using your `calcPairAffinity` function:
```python
import jax
import jax.numpy as jnp

def calcPairAffinity(ind1, ind2, imgs, gamma=0.5):
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    normed_diff = diff / image1.size
    val = jnp.exp(-gamma * normed_diff)
    val = val.astype(jnp.float16)
    return val

imgs = jax.random.normal(jax.random.key(0), (100, 5, 5))
inds = jnp.arange(imgs.shape[0])
inds1, inds2 = map(jnp.ravel, jnp.meshgrid(inds, inds))

def f(inds):
    return calcPairAffinity(*inds, imgs, 0.5)

result_vmap = jax.vmap(f)((inds1, inds2))
result_batched = jax.lax.map(f, (inds1, inds2), batch_size=1000)

assert jnp.allclose(result_vmap, result_batched)
```
Original answer
This is a frequent request, but unfortunately there's not yet (as of JAX version 0.4.20) any built-in utility to do chunked/batched vmap (`xmap` does have some functionality along these lines, but it is experimental/incomplete and I wouldn't recommend relying on it). Adding chunking to `vmap` is tracked in https://github.com/google/jax/issues/11319, and there's some code there that does a limited version of what you have in mind. Hopefully something like what you describe will be possible with JAX's built-in `vmap` soon. In the meantime, you might think about applying `vmap` to chunks manually, in the way you describe in your question.
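Alternatively, if fully sequential execution is acceptable, one stopgap that works even without a chunked vmap is plain `jax.lax.map` (no `batch_size`), which scans over the leading axis one element at a time. This is a sketch of my own (the function here is an adapted version of your `calcPairAffinity` that takes the index pair as a single argument, as `lax.map` expects):

```python
import jax
import jax.numpy as jnp

def calc_pair_affinity_seq(pair, imgs, gamma):
    # pair is a (ind1, ind2) tuple, as produced by lax.map over a tuple of arrays.
    ind1, ind2 = pair
    diff = jnp.sum(jnp.abs(imgs[ind1] - imgs[ind2]))
    return jnp.exp(-gamma * diff / imgs[ind1].size)

imgs = jax.random.normal(jax.random.PRNGKey(0), (100, 5, 5))
inds = jnp.arange(imgs.shape[0])
inds1, inds2 = map(jnp.ravel, jnp.meshgrid(inds, inds))

# lax.map compiles the function once and scans it sequentially, so only one
# pair's intermediates are live at a time: minimal memory, but slower than vmap.
result = jax.lax.map(lambda p: calc_pair_affinity_seq(p, imgs, 0.5), (inds1, inds2))
```

This trades throughput for a near-constant memory footprint, which may be enough to get a 10,000-image problem running at all.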