python, python-typing, mypy, jax

Why does mypy think adding two JAX arrays returns a NumPy array?


Consider the following file:

import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b

Running mypy mypytest.py returns the following error:

mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")

For some reason it believes that adding two jax.numpy.ndarrays returns a NumPy array of bools. Am I doing something wrong? Or is this a bug in mypy, or in JAX's type annotations?


Solution

  • At least statically, jnp.ndarray is a subclass of np.ndarray with only minimal modifications:

    class ndarray(np.ndarray, metaclass=_ArrayMeta):
      dtype: np.dtype
      shape: Tuple[int, ...]
      size: int
    
      def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
                   order=None):
        raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                        " Use jax.numpy.array, or jax.numpy.zeros instead.")
    

    As such, it inherits np.ndarray's method type signatures, including the return type of __add__.

    The runtime behaviour presumably comes from the jnp.array function. Unless I've missed some stub files or type trickery, the result of jnp.array is accepted wherever a jnp.ndarray is expected simply because jnp.array is unannotated, so mypy treats its return value as Any. You can test this with

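    import jax.numpy as jnp
    
    def foo(_: str) -> None:
        pass
    
    # If jnp.array is unannotated, mypy sees its result as Any and accepts this call.
    foo(jnp.array(0))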
    

    which passes mypy, even though the argument is obviously not a str.

    So, to answer your questions: I don't think you're doing anything wrong. It's a bug in the sense that it's probably not what the JAX authors intended. The inferred type is not actually incorrect, though: statically, a jnp.ndarray is an np.ndarray, so adding two of them does give you an np.ndarray.

    As for why bools, that's likely because jnp.ndarray carries no generic parameters (it subclasses a bare np.ndarray), and the first valid overload for __add__ on np.ndarray is

        @overload
        def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...  # type: ignore[misc]
    

    so mypy simply settles on that first overload, and the result type defaults to bool_.
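
The inheritance argument above can be reproduced without JAX or NumPy. The following is a minimal sketch, not JAX's actual implementation: Base, Sub, make_sub and add are hypothetical stand-ins for np.ndarray, jnp.ndarray, jnp.array and the function from the question. It assumes only that the subclass inherits an __add__ annotated to return the base class, and that the factory function has no annotations.

```python
from typing import get_type_hints


class Base:
    """Hypothetical stand-in for np.ndarray."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other: "Base") -> "Base":
        # At runtime this builds an instance of the caller's own class,
        # but the annotation only promises a Base.
        return type(self)(self.value + other.value)


class Sub(Base):
    """Hypothetical stand-in for jnp.ndarray: it does not redeclare __add__."""


def make_sub(value):  # deliberately unannotated, like jnp.array in the answer
    return Sub(value)


def add(a: Sub, b: Sub) -> Sub:
    # A static checker only sees the inherited signature, so mypy flags this
    # return as incompatible (got "Base", expected "Sub").
    return a + b


# The inherited annotation still names the base class.
print(get_type_hints(Sub.__add__)["return"].__name__)
# The unannotated factory exposes no hints at all, so a checker falls back to Any.
print(get_type_hints(make_sub))
```

Running mypy on this file should flag the return statement in add in the same way as in the question, while the script itself runs fine and add returns a Sub at runtime. That is the same gap between the static and runtime views described above, only without the bool_ overload detail.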