Can arguments be passed to a `jax.numpy.ufunc` within a `jax.numpy.ndarray.at` call?
The following is an attempt to replicate `jax.numpy.ndarray.at[...].add(...)` with a custom ufunc:
```python
import numpy as np
import jax.numpy as jnp

def myadd(a, b=1):
    return a + b

umyadd = jnp.frompyfunc(myadd, 2, 1, identity=0)
x = jnp.arange(4)

# call jnp.add(x, x)
x.at[:].add(x)
# [0 2 4 6]

# call umyadd.at
umyadd.at(x, np.arange(x.size), x, inplace=False)
# [0 2 4 6]

# Default b=1 (can b be passed here?)
x.at[:].apply(umyadd)
# [1 2 3 4]
```
`arr.at[...].apply()` only accepts unary functions that map a scalar to a scalar. So you can pass `b` via a closure, as long as it is a scalar; for example:
```python
x.at[:].apply(lambda a: umyadd(a, 2))
# [2 3 4 5]
```
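A self-contained version of the closure approach is sketched below, assuming a JAX version that provides `jnp.frompyfunc` and `.at[...].apply()` (both used in the question). It also applies the same kind of closure to a slice, to show that only the indexed positions are transformed while the rest of `x` is left unchanged. The names `all_updated` and `tail_updated` are illustrative only.

```python
import jax.numpy as jnp


def myadd(a, b=1):
    return a + b


umyadd = jnp.frompyfunc(myadd, 2, 1, identity=0)
x = jnp.arange(4)

# b must be a scalar: apply() needs a function mapping one scalar to one scalar,
# so the extra argument is captured by the closure instead of being passed in.
b = 2
all_updated = x.at[:].apply(lambda a: umyadd(a, b))

# Same idea on a slice: only indices 1..3 are transformed, index 0 is untouched.
tail_updated = x.at[1:].apply(lambda a: umyadd(a, 10))

print(all_updated.tolist())   # [2, 3, 4, 5]
print(tail_updated.tolist())  # [0, 11, 12, 13]
```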
But there is no way to pass `b=jnp.arange(4)` within `apply()`, because then the applied function no longer maps a scalar to a scalar.
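If `b` really has to be an array, one option is to skip `apply()` and use the binary form `umyadd.at(x, indices, b, inplace=False)`, which the question already shows producing `[0 2 4 6]`; JAX arrays are immutable, so `inplace=False` is passed. A second option, sketched here as an assumption rather than an official recipe, is an explicit gather-compute-scatter with `.at[idx].set(...)`. These two agree only when the indices are unique: NumPy's `ufunc.at` is unbuffered, so a repeated index is updated once per occurrence, whereas a single `.set(...)` writes each position only once. The variable names below are illustrative.

```python
import jax.numpy as jnp


def myadd(a, b=1):
    return a + b


umyadd = jnp.frompyfunc(myadd, 2, 1, identity=0)
x = jnp.arange(4)
b = jnp.arange(4)         # array-valued second argument
idx = jnp.arange(x.size)  # every position, each exactly once

# Binary ufunc.at: b is passed positionally after the indices.
# JAX cannot update arrays in place, so inplace=False is required.
y_ufunc_at = umyadd.at(x, idx, b, inplace=False)

# Gather x[idx], apply the binary ufunc elementwise with b, scatter back.
# Equivalent to the line above here only because idx has no duplicates.
y_set = x.at[idx].set(umyadd(x[idx], b))

print(y_ufunc_at.tolist())  # [0, 2, 4, 6]
print(y_set.tolist())       # [0, 2, 4, 6]
```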