python, numpy, python-typing

Numpy Typing with specific shape and datatype


I'm currently trying to use NumPy typing more to make my code clearer; however, I've reached a limit that I can't get past.

Is it possible to specify both a particular shape and the corresponding data type? For example:

shape = (4,)
dtype = np.int32

My attempts so far look like this (but all of them raised errors):

First attempt:

import numpy as np

def foo(x: np.ndarray[(4,), np.dtype[np.int32]]):
    ...
result -> 'numpy._DTypeMeta' object is not subscriptable

Second attempt:

import numpy as np
import numpy.typing as npt

def foo(x: npt.NDArray[(4,), np.int32]):
    ...
result -> Too many arguments for numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]

Unfortunately, I also can't find any information about this in the documentation, or I only get errors when I implement it the way it is documented.


Solution

  • Currently, numpy.typing.NDArray only accepts a dtype (scalar type), like so: numpy.typing.NDArray[numpy.int32]. You do have some options, though.
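
    A quick runtime check illustrates this. It is a minimal sketch that assumes NumPy 1.22 or newer, where np.ndarray and np.dtype are subscriptable at runtime; the name IntArray is only for illustration. NDArray already fills in the shape slot of np.ndarray itself, so its single free parameter is the scalar type, which is why passing a shape as an extra argument fails with "Too many arguments":

```python
import typing

import numpy as np
import numpy.typing as npt

# Supplying NDArray's single parameter (the scalar type) fills in the dtype slot.
IntArray = npt.NDArray[np.int32]
shape_arg, dtype_arg = typing.get_args(IntArray)

print(typing.get_origin(IntArray))  # the underlying class, np.ndarray
print(shape_arg)  # the shape slot NDArray fixes for you (typing.Any in NumPy 1.x)
print(dtype_arg)  # the dtype slot, np.dtype[np.int32]

# A shape cannot be passed as an extra argument; it is rejected at runtime too.
try:
    npt.NDArray[(4,), np.int32]
    rejected = False
except TypeError as exc:
    rejected = True
    print("TypeError:", exc)
```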

    Use typing.Annotated

    typing.Annotated allows you to create an alias for a type and to bundle some extra information with it.

    In a module such as my_types.py, you would define all the shape variations you want to hint. Annotated metadata can be any object, so a plain tuple works for the shape:

    from typing import Annotated, TypeVar
    import numpy as np
    import numpy.typing as npt
    
    
    DType = TypeVar("DType", bound=np.generic)
    
    # Don't use Literal[3, 3] for the shape: since Python 3.9.1, Literal
    # de-duplicates its arguments, so it collapses to Literal[3].
    Array4 = Annotated[npt.NDArray[DType], (4,)]
    Array3x3 = Annotated[npt.NDArray[DType], (3, 3)]
    ArrayNxNx3 = Annotated[npt.NDArray[DType], ("N", "N", 3)]
    

    Then, in foo.py, you can supply a NumPy dtype and use the aliases as type hints:

    import numpy as np
    from my_types import Array4
    
    
    def foo(arr: Array4[np.int32]):
        assert arr.shape == (4,)
    

    Mypy will recognize arr as an np.ndarray and check it as such. The shape, however, can only be checked at runtime, for example with an assert as above.

    If you don't like the assertion, you can get creative and define a function that does the checking for you:

    import numpy as np
    
    
    def assert_match(arr, array_type):
        # The shape tuple from the Annotated metadata, e.g. (4,).
        hinted_shape = array_type.__metadata__[0]
        # array_type.__args__[0] is np.ndarray[<shape>, np.dtype[<scalar type>]].
        hinted_dtype_type = array_type.__args__[0].__args__[1]
        hinted_dtype = hinted_dtype_type.__args__[0]
        assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
        # Exact comparison: named dimensions such as "N" are not supported here.
        assert arr.shape == hinted_shape, "Shape does not match"
    
    
    assert_match(some_array, Array4[np.int32])
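
    An exact shape comparison cannot handle named dimensions such as the "N" in ArrayNxNx3. Below is a minimal sketch of a more lenient check. It assumes the shape is stored as a plain tuple in the Annotated metadata, and shape_matches, assert_match_named and the aliases it defines are hypothetical names for illustration. Integer entries must match exactly, while each string names a dimension that must have the same size everywhere it appears:

```python
from typing import Annotated, TypeVar, get_args

import numpy as np
import numpy.typing as npt

DType = TypeVar("DType", bound=np.generic)

# Hypothetical aliases: the shape lives in the Annotated metadata as a tuple;
# strings such as "N" name a dimension of arbitrary size.
Array4 = Annotated[npt.NDArray[DType], (4,)]
ArrayNxNx3 = Annotated[npt.NDArray[DType], ("N", "N", 3)]


def shape_matches(shape, hinted_shape):
    """Return True if `shape` fits `hinted_shape`."""
    if len(shape) != len(hinted_shape):
        return False
    sizes = {}  # dimension name -> size it was first bound to
    for size, hint in zip(shape, hinted_shape):
        if isinstance(hint, str):
            # Bind the name on first use; later uses must agree with it.
            if sizes.setdefault(hint, size) != size:
                return False
        elif size != hint:
            return False
    return True


def assert_match_named(arr, array_type):
    hinted_shape = array_type.__metadata__[0]
    # array_type.__origin__ is np.ndarray[<shape>, np.dtype[<scalar type>]].
    dtype_alias = get_args(array_type.__origin__)[1]
    hinted_dtype = get_args(dtype_alias)[0]
    assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
    assert shape_matches(arr.shape, hinted_shape), "Shape does not match"


assert_match_named(np.zeros(4, dtype=np.int32), Array4[np.int32])
assert_match_named(np.zeros((5, 5, 3)), ArrayNxNx3[np.float64])
```

    Binding names on first use keeps the check a single pass over the shape; a stricter variant could also reject arrays whose dtype is only a subtype of the hinted one.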
    

    Use nptyping

    Another option is to use the third-party library nptyping (yes, I am the author).

    You could drop my_types.py, as it would no longer be needed.

    Your foo.py would become something like:

    from nptyping import NDArray, Shape, Int32
    
    
    def foo(arr: NDArray[Shape["4"], Int32]):
        assert isinstance(arr, NDArray[Shape["4"], Int32])
    

    Use beartype + typing.Annotated

    Another third-party library you could use is beartype. It supports a variant of the typing.Annotated approach and does the runtime checking for you.

    You would bring back my_types.py, with content similar to:

    from beartype import beartype
    from beartype.vale import Is
    from typing import Annotated
    import numpy as np
    
    
    Int32Array4 = Annotated[np.ndarray, Is[lambda array:
        array.shape == (4,) and np.issubdtype(array.dtype, np.int32)]]
    Int32Array3x3 = Annotated[np.ndarray, Is[lambda array:
        array.shape == (3,3) and np.issubdtype(array.dtype, np.int32)]]
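
    To get a feel for the general idea behind such validators, here is a heavily simplified, hypothetical stand-in that uses only the standard library and NumPy. check_annotated and is_int32_array4 are illustrative names, not part of beartype's API, and beartype itself handles far more (nested hints, return values, performance). The sketch reads the Annotated metadata with typing.get_type_hints(..., include_extras=True) and calls every callable metadata entry as a predicate:

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints

import numpy as np


def is_int32_array4(array):
    # Same predicate as the Is[...] validator above, as a plain function.
    return array.shape == (4,) and np.issubdtype(array.dtype, np.int32)


Int32Array4 = Annotated[np.ndarray, is_int32_array4]


def check_annotated(func):
    """Hypothetical decorator: validate arguments hinted as Annotated[cls, predicate, ...]."""
    hints = get_type_hints(func, include_extras=True)
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # Only Annotated hints are checked in this sketch.
            base, *metadata = get_args(hint)
            if not isinstance(value, get_origin(base) or base):
                raise TypeError(f"{name} must be {base!r}, got {type(value).__name__}")
            for predicate in metadata:
                if callable(predicate) and not predicate(value):
                    raise TypeError(f"{name} failed {predicate.__name__}")
        return func(*args, **kwargs)

    return wrapper


@check_annotated
def foo(arr: Int32Array4):
    return int(arr.sum())


print(foo(np.arange(4, dtype=np.int32)))  # 6
```

    Calling foo(np.zeros(3, dtype=np.int32)) would raise TypeError, because the predicate rejects the shape before the function body runs.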
    

    And your foo.py would become:

    import numpy as np
    from beartype import beartype
    from my_types import Int32Array4 
    
    
    @beartype
    def foo(arr: Int32Array4):
        ...  # Runtime type checked by beartype.
    

    Use beartype + nptyping

    You could also combine both libraries.

    You can remove my_types.py again, and your foo.py would become something like:

    from nptyping import NDArray, Shape, Int32
    from beartype import beartype
    
    
    @beartype
    def foo(arr: NDArray[Shape["4"], Int32]):
        ...  # Runtime type checked by beartype.