jax

Using JAX ndarray.at apply(ufunc) with arguments


Can arguments be passed to a jax.numpy.ufunc within a jax.numpy.ndarray.at call?

The following is an attempt to replicate jax.numpy.ndarray.at[...].add(...) with a custom ufunc:

import numpy as np
import jax.numpy as jnp

def myadd(a, b=1):
    return a + b

umyadd = jnp.frompyfunc(myadd, 2, 1, identity=0)

x = jnp.arange(4)

# call jnp.add(x,x)
x.at[:].add(x)
# [0 2 4 6]

# call umyadd.at
umyadd.at(x, np.arange(x.size), x, inplace=False)
# [0 2 4 6]

# Default b=1 (can b be passed here?)
x.at[:].apply(umyadd)
# [1 2 3 4]


Solution

  • arr.at[...].apply() only accepts unary functions that map a scalar to a scalar. You can still pass b through a closure, as long as b is a scalar; for example:

    x.at[:].apply(lambda a: umyadd(a, 2))
    # [2 3 4 5]
    

    However, there is no way to pass an array such as b=jnp.arange(4) through apply(), because the applied function would then no longer map a scalar to a scalar.
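The closure approach, together with the ufunc's own .at method (already used in the question) for the array-valued case, can be sketched as follows. This is a minimal sketch assuming the myadd/umyadd definitions from the question; functools.partial is simply another way to build the closure, and the array b = [10, 20, 30, 40] is an illustrative value, not something from the original post.

```python
import functools

import jax.numpy as jnp


def myadd(a, b=1):
    return a + b


# Same construction as in the question: a binary ufunc built from myadd.
umyadd = jnp.frompyfunc(myadd, 2, 1, identity=0)

x = jnp.arange(4)

# Scalar b: bind it before calling .apply(), so the applied function
# still takes one scalar and returns one scalar.
via_lambda = x.at[:].apply(lambda a: umyadd(a, 2))
via_partial = x.at[:].apply(functools.partial(myadd, b=2))

# Array-valued b: .apply() cannot take it, but the ufunc's own .at method
# accepts the second operand directly. JAX requires inplace=False here
# because its arrays are immutable.
b = jnp.array([10, 20, 30, 40])
via_ufunc_at = umyadd.at(x, jnp.arange(x.size), b, inplace=False)

print(via_lambda)    # [2 3 4 5]
print(via_partial)   # [2 3 4 5]
print(via_ufunc_at)  # [10 21 32 43]
```

Binding a scalar b ahead of time keeps the function unary, which is what apply() needs; once b varies per element, the update takes two operands, so ufunc.at is the call that accepts it.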