
Static type checking or IDE intelligence support for a numpy array/matrix shape possible?


Is it possible to have static type checking or IDE intelligence support for a numpy array/matrix shape?

For example, if I imagine something like this:

A_MxN: NDArray(3,2) = ...
B_NxM: NDArray(2,3) = ...

Even better would be:

M = 3
N = 2
A_MxN: NDArray(M,N) = ...
B_NxM: NDArray(N,M) = ...

And if I assign B to A, I would like to get an IDE hint at development time (not at runtime) that the shapes are different.

Something like:

A_MxN = B_NxM
Hint/Error: Declared shape 3,2 is not compatible with assigned shape 2,3

As mentioned by @simon, this seems to be possible:

from typing import Literal
import numpy as np

M = Literal[3]
N = Literal[2]
A_MxN: np.ndarray[tuple[M, N], np.dtype[np.int32]]

But if I assign an array that does not fulfill the shape requirement, the type checker does not report an error. Does anyone know of a type checker, such as mypy or pyright, that supports this feature?


Solution

  • It's possible to type the shape of an array, as mentioned above. But at the moment (numpy 2.1.1), the shape type of an ndarray is lost in most of NumPy's own functions.

    Shape-typing support is gradually improving, though, and I'm personally involved in that work.

    In the meantime, you can already use shape-typing in functions you write yourself without too much hassle. For instance, in Python 3.12 (with PEP 695 syntax) you can do:

    from typing import Any
    import numpy as np
    
    def get_shape[ShapeT: tuple[int, ...]](a: np.ndarray[ShapeT, Any]) -> ShapeT:
        return a.shape
    

    This will be valid in all static type-checkers with numpy>=2.1.
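
    On Python versions before 3.12, the same function can be sketched with an explicit TypeVar instead of the PEP 695 syntax. This is an assumed equivalent rather than something the snippet above requires; the `from __future__ import annotations` line is only there so the subscripted annotation is not evaluated at runtime:

```python
from __future__ import annotations

from typing import Any, Tuple, TypeVar

import numpy as np

# Same upper bound as `ShapeT: tuple[int, ...]` in the PEP 695 version.
ShapeT = TypeVar("ShapeT", bound=Tuple[int, ...])


def get_shape(a: np.ndarray[ShapeT, Any]) -> ShapeT:
    # The return type is whatever shape type the argument was declared with.
    return a.shape


# At runtime this is just the ordinary shape tuple.
shape = get_shape(np.zeros((3, 2)))
```

    At runtime nothing changes; the TypeVar only tells the type checker that the returned tuple has the same type as the array's shape parameter.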

    If this syntax is too verbose for you, you could use the lightweight optype library (of which I'm the author) to make it more readable:

    from optype.numpy import Array, AtLeast0D
    
    def get_shape[ShapeT: AtLeast0D](a: Array[ShapeT]) -> ShapeT:
        return a.shape
    

    Array is a handy alias for ndarray that uses type parameter defaults (PEP 696). See the docs if you want to know the details.
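
    To tie this back to the question, here is a minimal sketch of how a shape mismatch could surface once the shape type is actually present. The names `Mat3x2`, `Mat2x3` and `takes_3x2` are hypothetical, and `cast` is used on the assumption that `np.zeros` does not carry the literal shape in its return type:

```python
from __future__ import annotations

from typing import Literal, Tuple, cast

import numpy as np

# Hypothetical aliases for the two shapes from the question.
Mat3x2 = np.ndarray[Tuple[Literal[3], Literal[2]], np.dtype[np.int32]]
Mat2x3 = np.ndarray[Tuple[Literal[2], Literal[3]], np.dtype[np.int32]]

# cast() attaches the shape type by hand; it performs no runtime check.
a_3x2: Mat3x2 = cast(Mat3x2, np.zeros((3, 2), dtype=np.int32))
b_2x3: Mat2x3 = cast(Mat2x3, np.zeros((2, 3), dtype=np.int32))


def takes_3x2(a: Mat3x2) -> int:
    # Only arrays declared as 3x2 should be accepted by a type checker.
    return int(a.sum())


total = takes_3x2(a_3x2)  # declared shape matches the parameter

# The next call is left commented out: the two tuple types differ, so a
# checker such as mypy or pyright should report an incompatible argument.
# takes_3x2(b_2x3)
```

    The design trade-off is that `cast` is a promise to the checker rather than a verification, so a wrong cast goes unnoticed until NumPy's own functions propagate shapes themselves.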