Consider the following file:
```python
import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b
```
Running `mypy mypytest.py` returns the following error:

```
mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")
```
For some reason, mypy believes that adding two `jax.numpy.ndarray`s returns a NumPy array of bools. Am I doing something wrong, or is this a bug in mypy or in JAX's type annotations?
At least statically, `jnp.ndarray` is a subclass of `np.ndarray` with very minimal modifications:
```python
class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")
```
As such, it inherits `np.ndarray`'s method type signatures.
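The same effect can be reproduced without JAX or NumPy. In this minimal sketch, `Base` and `Sub` are hypothetical stand-ins for `np.ndarray` and `jnp.ndarray`: the subclass does not override `__add__`, so it inherits the base class's return annotation, and mypy rejects a function that promises to return the subclass.

```python
class Base:
    """Hypothetical stand-in for np.ndarray: __add__ is annotated to return Base."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __add__(self, other: "Base") -> "Base":
        return Base(self.value + other.value)


class Sub(Base):
    """Hypothetical stand-in for jnp.ndarray: subclasses Base, does not override __add__."""


def add_subs(a: Sub, b: Sub) -> Sub:
    # mypy flags this line with:
    #   error: Incompatible return value type (got "Base", expected "Sub")
    # because Sub inherits Base.__add__'s annotation unchanged.
    return a + b
```

At runtime `add_subs(Sub(1), Sub(2))` also returns a plain `Base`, so in this toy version the checker's complaint is accurate.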
I guess the actual runtime behaviour is provided by the `jnp.array` function. Unless I've missed some stub files or type trickery, the result of `jnp.array` is accepted wherever a `jnp.ndarray` is expected simply because `jnp.array` is untyped, so mypy treats its return value as `Any`. You can test this out with
```python
import jax.numpy as jnp

def foo(_: str) -> None:
    pass

foo(jnp.array(0))
```

which passes mypy.
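To see that this comes from the missing annotations rather than anything JAX-specific, here is a pure-Python sketch. `untyped_make` and `typed_make` are hypothetical helpers standing in for an unannotated and an annotated array constructor; only the unannotated one slips past the `str` parameter.

```python
def untyped_make(x):
    # No annotations, so mypy treats the return type as Any.
    return x


def typed_make(x: int) -> int:
    # Annotated, so mypy knows the result is an int.
    return x


def foo(_: str) -> None:
    pass


foo(untyped_make(0))  # accepted by mypy: Any is compatible with str
foo(typed_make(0))  # mypy: Argument 1 to "foo" has incompatible type "int"; expected "str"
```

Both calls run fine at runtime because `foo` ignores its argument; the difference is purely in what mypy can see.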
So, to answer your questions: I don't think you're doing anything wrong. It's a bug in the sense that it's probably not what the JAX developers intended, but it isn't actually incorrect: since a `jnp.ndarray` is an `np.ndarray`, adding two `jnp.ndarray`s does give you an `np.ndarray`.
As for why it's `bool`s, that's likely because your `jnp.ndarray`s are missing generic parameters (`jnp.ndarray` subclasses a bare `np.ndarray`), and the first valid overload of `__add__` on `np.ndarray` is

```python
@overload
def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...  # type: ignore[misc]
```

so it has just defaulted to `bool`.
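For contrast, a minimal sketch with plain NumPy arrays whose annotations do carry a dtype parameter. With `npt.NDArray[np.float64]`, the `bool_` overload no longer matches `self`, so mypy should pick a floating-point overload instead; the exact inferred type varies between NumPy versions. This does not directly help with `jnp.ndarray`, which, per the class definition above, takes no type parameters.

```python
import numpy as np
import numpy.typing as npt


def add_floats(
    a: npt.NDArray[np.float64], b: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    # The explicit dtype parameter rules out the NDArray[bool_] overload
    # of __add__, so mypy resolves to a floating-point overload.
    return a + b
```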