
How do I annotate a function whose return type depends on its argument?


In Python, I often write functions that filter a collection to find instances of specific subtypes. For example, I might look for a specific kind of node in a DOM or a specific kind of event in a log:

def find_pre(soup: TagSoup) -> List[tags.pre]:
    """Find all <pre> nodes in `tag_soup`."""
    …

def filter_errors(log: List[LogEvent]) -> List[LogError]:
    """Keep only errors from `log`.""" 
    …

Writing types for these functions is easy. But what about generic versions of these functions that take an argument to specify which types to return?

def find_tags(tag_soup: TagSoup, T: type) -> List[T]:
    """Find all nodes of type `T` in `tag_soup`."""
    …

def filter_events(log: List[LogEvent], T: type) -> List[T]:
    """Keep only events of type `T` from `log`."""
    …

(The signatures above are wrong: I can't refer to T in the return type.)

This is a fairly common design: docutils has node.traverse(T: type), BeautifulSoup has soup.find_all(), etc. Of course it can get arbitrarily complex, but can Python type annotations handle simple cases like the above?

Here is a minimal working example (MWE) to make it concrete:

from dataclasses import dataclass
from typing import *

@dataclass
class Packet: pass

@dataclass
class Done(Packet): pass

@dataclass
class Exn(Packet):
    exn: str
    loc: Tuple[int, int]

@dataclass
class Message(Packet):
    ref: int
    msg: str

Stream = Callable[[], Union[Packet, None]]

def stream_response(stream: Stream, types) -> Iterator[??]:
    while response := stream():
        if isinstance(response, Done): return
        if isinstance(response, types): yield response

def print_messages(stream: Stream):
    for m in stream_response(stream, Message):
        print(m.msg) # Error: Cannot access member "msg" for "Packet"

msgs = iter((Message(0, "hello"), Exn("Oops", (1, 42)), Done()))
print_messages(lambda: next(msgs))

Pyright says:

  29:17 - error: Cannot access member "msg" for type "Packet"
  Member "msg" is unknown (reportGeneralTypeIssues)

In the example above, is there a way to annotate stream_response so that Python type checkers will accept the definition of print_messages?


Solution

  • Okay, here we go. It passes MyPy --strict, but it isn't pretty.

    What's going on here

    For a given class A, the type of an instance of A is A (obviously). But what is the type of A itself? At runtime, A is an instance of type (or of its custom metaclass, if it has one). However, annotating a parameter with plain type doesn't tell the type checker much. The typing syntax for going "one step up" in the type hierarchy is Type[A] (or type[A] on Python 3.9+), meaning "the class A or one of its subclasses". So if a function myfunc returns an instance of the class passed to it as an argument, we can annotate it fairly simply:

    from typing import TypeVar, Type
    
    T = TypeVar('T')
    
    def myfunc(some_class: Type[T]) -> T:
        # do some stuff
        return some_class()
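The same Type[T] trick already covers the single-class version of the question's find_tags/filter_events, with no overloads needed. Below is a minimal sketch; filter_instances is a hypothetical helper name, and it assumes the checker narrows isinstance(item, cls) to T when cls is annotated Type[T], as mypy and Pyright do.

```python
from typing import Iterable, Iterator, Type, TypeVar

T = TypeVar("T")

def filter_instances(items: Iterable[object], cls: Type[T]) -> Iterator[T]:
    """Yield only the elements of `items` that are instances of `cls`."""
    for item in items:
        # isinstance() with a Type[T] argument narrows `item` to T here
        if isinstance(item, cls):
            yield item

# The result is inferred as Iterator[int], so `n % 2` type-checks
evens = [n for n in filter_instances([1, "a", 2, None, 4], int) if n % 2 == 0]
print(evens)  # [2, 4]
```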
    

    Your case, however, is more complex: you might pass a single class as the types argument, or a tuple of two classes, or three, and so on. We can handle this with typing.overload, which lets us register multiple signatures for the same function. The @overload-decorated definitions exist purely for the type checker: at runtime they are replaced by the final, undecorated implementation and never called, so their bodies can be trivial. By convention, the body is just a docstring or a literal ellipsis (...).
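Here is a minimal sketch of those mechanics with just two overloads, using a hypothetical first_match helper: the decorated stubs only describe the accepted call shapes and their return types, while the last, undecorated definition is the only code that actually runs.

```python
from typing import Iterable, Tuple, Type, TypeVar, Union, overload

A = TypeVar("A")
B = TypeVar("B")

@overload
def first_match(items: Iterable[object], types: Type[A]) -> A: ...
@overload
def first_match(
    items: Iterable[object], types: Tuple[Type[A], Type[B]]
) -> Union[A, B]: ...

# The real implementation: only this body ever executes at runtime.
def first_match(
    items: Iterable[object], types: Union[type, Tuple[type, ...]]
) -> object:
    for item in items:
        if isinstance(item, types):
            return item
    raise LookupError("no item matched")

word = first_match([1, "x", 2.5], str)          # checker infers str
num = first_match([1.5, "y", 2], (int, float))  # checker infers Union[int, float]
print(word.upper(), num)  # X 1.5
```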

    I don't think there's a way of generalising these overloads over an arbitrary number of types, which is why the maximum number of classes you want to accept in the types parameter matters: you have to tediously enumerate one signature per tuple length. If you go down this route, you may want to move the @overload signatures into a separate .pyi stub file.

    from dataclasses import dataclass
    from typing import (
        Callable,
        Tuple,
        Union,
        Iterator,
        overload,
        TypeVar,
        Type,
    )
    
    @dataclass
    class Packet: pass
    
    P1 = TypeVar('P1', bound=Packet)
    P2 = TypeVar('P2', bound=Packet)
    P3 = TypeVar('P3', bound=Packet)
    P4 = TypeVar('P4', bound=Packet)
    P5 = TypeVar('P5', bound=Packet)
    P6 = TypeVar('P6', bound=Packet)
    P7 = TypeVar('P7', bound=Packet)
    P8 = TypeVar('P8', bound=Packet)
    P9 = TypeVar('P9', bound=Packet)
    P10 = TypeVar('P10', bound=Packet)
    
    @dataclass
    class Done(Packet): pass
    
    @dataclass
    class Exn(Packet):
        exn: str
        loc: Tuple[int, int]
    
    @dataclass
    class Message(Packet):
        ref: int
        msg: str
    
    Stream = Callable[[], Union[Packet, None]]
    
    @overload
    def stream_response(stream: Stream, types: Type[P1]) -> Iterator[P1]:
        """Signature if exactly one type is passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2]]
    ) -> Iterator[Union[P1, P2]]:
        """Signature if exactly two types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3]]
    ) -> Iterator[Union[P1, P2, P3]]:
        """Signature if exactly three types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4]]
    ) -> Iterator[Union[P1, P2, P3, P4]]:
        """Signature if exactly four types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4], Type[P5]]
    ) -> Iterator[Union[P1, P2, P3, P4, P5]]:
        """Signature if exactly five types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4], Type[P5], Type[P6]]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6]]:
        """Signature if exactly six types are passed in for the `types` parameter"""
    
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[
            Type[P1], 
            Type[P2],
            Type[P3],
            Type[P4], 
            Type[P5],
            Type[P6],
            Type[P7]
        ]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7]]:
        """Signature if exactly seven types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[
            Type[P1], 
            Type[P2],
            Type[P3],
            Type[P4], 
            Type[P5],
            Type[P6],
            Type[P7],
            Type[P8]
        ]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8]]:
        """Signature if exactly eight types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[
            Type[P1], 
            Type[P2],
            Type[P3],
            Type[P4], 
            Type[P5],
            Type[P6],
            Type[P7],
            Type[P8],
            Type[P9]
        ]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8, P9]]:
        """Signature if exactly nine types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[
            Type[P1], 
            Type[P2],
            Type[P3],
            Type[P4], 
            Type[P5],
            Type[P6],
            Type[P7],
            Type[P8],
            Type[P9],
            Type[P10]
        ]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8, P9, P10]]:
        """Signature if exactly ten types are passed in for the `types` parameter"""
    
    # We have to be more generic in our type-hinting for the concrete implementation 
    # Otherwise, MyPy struggles to figure out that it's a valid argument to `isinstance`
    def stream_response(
        stream: Stream,
        types: Union[type, Tuple[type, ...]]
    ) -> Iterator[Packet]:
        
        while response := stream():
            if isinstance(response, Done): return
            if isinstance(response, types): yield response
    
    def print_messages(stream: Stream) -> None:
        for m in stream_response(stream, Message):
            print(m.msg)
    
    msgs = iter((Message(0, "hello"), Exn("Oops", (1, 42)), Done()))
    print_messages(lambda: next(msgs))
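To see what the tuple overloads buy you, here is a trimmed, self-contained sketch that keeps only the one- and two-type overloads from the code above and consumes both Message and Exn packets. Because the loop variable is inferred as Union[Message, Exn], an isinstance check narrows it in each branch. The handle_packets function and its list-of-strings result are illustrative additions, not part of the original answer.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, List, Tuple, Type, TypeVar, Union, overload

@dataclass
class Packet: pass

@dataclass
class Done(Packet): pass

@dataclass
class Exn(Packet):
    exn: str
    loc: Tuple[int, int]

@dataclass
class Message(Packet):
    ref: int
    msg: str

P1 = TypeVar("P1", bound=Packet)
P2 = TypeVar("P2", bound=Packet)
Stream = Callable[[], Union[Packet, None]]

@overload
def stream_response(stream: Stream, types: Type[P1]) -> Iterator[P1]: ...
@overload
def stream_response(
    stream: Stream, types: Tuple[Type[P1], Type[P2]]
) -> Iterator[Union[P1, P2]]: ...
def stream_response(
    stream: Stream, types: Union[type, Tuple[type, ...]]
) -> Iterator[Packet]:
    while response := stream():
        if isinstance(response, Done):
            return
        if isinstance(response, types):
            yield response

def handle_packets(stream: Stream) -> List[str]:
    out: List[str] = []
    # p is inferred as Union[Message, Exn] via the two-type overload
    for p in stream_response(stream, (Message, Exn)):
        if isinstance(p, Message):
            out.append(p.msg)  # narrowed to Message
        else:
            out.append(f"{p.exn} at {p.loc}")  # narrowed to Exn
    return out

packets = iter((Message(0, "hello"), Exn("Oops", (1, 42)), Done()))
print(handle_packets(lambda: next(packets)))  # ['hello', 'Oops at (1, 42)']
```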
    

    Strategies for making this less verbose

    If you want to make this more concise, one option is to introduce generic type aliases for the repeated constructs. The danger is that the intent of the type hints becomes harder to read, but it does make overloads 7-10 look a lot less horrific:

    from dataclasses import dataclass
    from typing import (
        Callable,
        Tuple,
        Union,
        Iterator,
        overload,
        TypeVar,
        Type,
    )
    
    @dataclass
    class Packet: pass
    
    P1 = TypeVar('P1', bound=Packet)
    P2 = TypeVar('P2', bound=Packet)
    P3 = TypeVar('P3', bound=Packet)
    P4 = TypeVar('P4', bound=Packet)
    P5 = TypeVar('P5', bound=Packet)
    P6 = TypeVar('P6', bound=Packet)
    P7 = TypeVar('P7', bound=Packet)
    P8 = TypeVar('P8', bound=Packet)
    P9 = TypeVar('P9', bound=Packet)
    P10 = TypeVar('P10', bound=Packet)
    
    _P = TypeVar('_P', bound=Packet)
    S = Type[_P]
    
    T7 = Tuple[S[P1], S[P2], S[P3], S[P4], S[P5], S[P6], S[P7]]
    T8 = Tuple[S[P1], S[P2], S[P3], S[P4], S[P5], S[P6], S[P7], S[P8]]
    T9 = Tuple[S[P1], S[P2], S[P3], S[P4], S[P5], S[P6], S[P7], S[P8], S[P9]]
    T10 = Tuple[S[P1], S[P2], S[P3], S[P4], S[P5], S[P6], S[P7], S[P8], S[P9], S[P10]]
    
    @dataclass
    class Done(Packet): pass
    
    @dataclass
    class Exn(Packet):
        exn: str
        loc: Tuple[int, int]
    
    @dataclass
    class Message(Packet):
        ref: int
        msg: str
    
    Stream = Callable[[], Union[Packet, None]]
    
    @overload
    def stream_response(stream: Stream, types: Type[P1]) -> Iterator[P1]:
        """Signature if exactly one type is passed in for the `types` parameter"""
    
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2]]
    ) -> Iterator[Union[P1, P2]]:
        """Signature if exactly two types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3]]
    ) -> Iterator[Union[P1, P2, P3]]:
        """Signature if exactly three types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4]]
    ) -> Iterator[Union[P1, P2, P3, P4]]:
        """Signature if exactly four types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4], Type[P5]]
    ) -> Iterator[Union[P1, P2, P3, P4, P5]]:
        """Signature if exactly five types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: Tuple[Type[P1], Type[P2], Type[P3], Type[P4], Type[P5], Type[P6]]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6]]:
        """Signature if exactly six types are passed in for the `types` parameter"""
    
    @overload
    def stream_response(
        stream: Stream, 
        types: T7[P1, P2, P3, P4, P5, P6, P7]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7]]:
        """Signature if exactly seven types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: T8[P1, P2, P3, P4, P5, P6, P7, P8]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8]]:
        """Signature if exactly eight types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: T9[P1, P2, P3, P4, P5, P6, P7, P8, P9]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8, P9]]:
        """Signature if exactly nine types are passed in for the `types` parameter"""
        
    @overload
    def stream_response(
        stream: Stream, 
        types: T10[P1, P2, P3, P4, P5, P6, P7, P8, P9, P10]
    ) -> Iterator[Union[P1, P2, P3, P4, P5, P6, P7, P8, P9, P10]]:
        """Signature if exactly ten types are passed in for the `types` parameter"""
    
    # We have to be more generic in our type-hinting for the concrete implementation 
    # Otherwise, MyPy struggles to figure out that it's a valid argument to `isinstance`
    def stream_response(
        stream: Stream,
        types: Union[type, Tuple[type, ...]]
    ) -> Iterator[Packet]:
        
        while response := stream():
            if isinstance(response, Done): return
            if isinstance(response, types): yield response
    
    def print_messages(stream: Stream) -> None:
        for m in stream_response(stream, Message):
            print(m.msg)
    
    msgs = iter((Message(0, "hello"), Exn("Oops", (1, 42)), Done()))
    print_messages(lambda: next(msgs))
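If the aliases feel opaque, note that S, T7 and friends are ordinary generic aliases: subscripting one substitutes its type variables in order of first appearance. The runtime sketch below uses a two-element alias, T2 (a hypothetical name following the same pattern as T7-T10), to show the expansion; type checkers perform the equivalent substitution statically.

```python
from typing import Tuple, Type, TypeVar

class Packet: pass
class Message(Packet): pass
class Exn(Packet): pass

P1 = TypeVar("P1", bound=Packet)
P2 = TypeVar("P2", bound=Packet)
_P = TypeVar("_P", bound=Packet)

S = Type[_P]              # one-parameter generic alias
T2 = Tuple[S[P1], S[P2]]  # two-parameter alias built from S

# Subscripting T2 substitutes P1 and P2, in that order
expanded = T2[Message, Exn]
print(expanded == Tuple[Type[Message], Type[Exn]])  # True
```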