python, pytorch, python-typing, jax

Hello World for jaxtyping?


I can't find any instructions or tutorials for getting started with jaxtyping. I tried the simplest possible program, and it fails to parse. I'm on Python 3.11. I don't see anything on the jaxtyping GitHub project about an upper bound on the Python version (the lower bound is Python 3.9), and it looks like it's actively maintained (the last commit was 8 hours ago). What step am I missing?

jaxtyping==0.2.36
numpy==2.1.3
torch==2.5.1
typeguard==4.4.1

(numpy seems to be required for some reason, even though I'm not using it.)

from typeguard import typechecked
from jaxtyping import Float
from torch import Tensor


@typechecked
def matmul(a: Float[Tensor, "m n"], b: Float[Tensor, "n p"]) -> Float[Tensor, "m p"]:
    """
    Matrix multiplication of two 2D arrays.
    """
    raise NotImplementedError("This function is not implemented yet.")

(venv) dspyz@dspyz-desktop:~/helloworld$ python matmul.py 
Traceback (most recent call last):
  File "/home/dspyz/helloworld/matmul.py", line 6, in <module>
    @typechecked
     ^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_decorators.py", line 221, in typechecked
    retval = instrument(target)
             ^^^^^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_decorators.py", line 72, in instrument
    instrumentor.visit(module_ast)
  File "/usr/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 598, in visit_Module
    self.generic_visit(node)
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 498, in generic_visit
    node = super().generic_visit(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/ast.py", line 494, in generic_visit
    value = self.visit(value)
            ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 672, in visit_FunctionDef
    with self._use_memo(node):
  File "/usr/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 556, in _use_memo
    new_memo.return_annotation = self._convert_annotation(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 582, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 421, in visit_Subscript
    [self.visit(item) for item in node.slice.elts],
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 421, in <listcomp>
    [self.visit(item) for item in node.slice.elts],
     ^^^^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/dspyz/helloworld/venv/lib/python3.11/site-packages/typeguard/_transformer.py", line 474, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/ast.py", line 50, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    m p
      ^
SyntaxError: invalid syntax

Solution

  • (jaxtyping author here)

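    Sadly, this is a known, long-standing bug in typeguard v4 that still hasn't been fixed. At a technical level: typeguard v4 re-reads and re-parses the source code of your function in order to instrument it, and any string it finds inside a subscripted annotation such as Float[Tensor, "m p"] gets parsed as if it were a Python expression (a forward reference). A jaxtyping shape string like "m p" is not valid Python, hence the SyntaxError at the bottom of your traceback.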

    I use typeguard==2.13.3 myself, which seems to be pretty robust.

    EDIT: removed some other suggested workarounds. These turned out not to, well, work. For now I just recommend pinning to that earlier version of typeguard.
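The failing step can be reproduced with nothing but the standard library. The last frames of the traceback show typeguard v4's visit_Constant handing the annotation string to ast.parse(node.value, mode="eval"). The sketch below makes only that one call on a few strings; it is not typeguard's code, and parses_as_expression is a hypothetical helper name. A space-separated shape such as "m p" is rejected, while a single name or an ordinary type expression parses fine.

```python
import ast


def parses_as_expression(annotation_string: str) -> bool:
    """Return True if the string is a valid Python expression.

    Hypothetical helper: it performs the same ast.parse(..., mode="eval")
    call that appears in the last frames of the typeguard v4 traceback.
    """
    try:
        ast.parse(annotation_string, mode="eval")
    except SyntaxError:
        return False
    return True


for text in ["m p", "m n", "batch", "List[int]"]:
    print(f"{text!r}: {parses_as_expression(text)}")
# 'm p': False
# 'm n': False
# 'batch': True
# 'List[int]': True
```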