
How can I annotate a function that takes a union, and returns one of the types in the union?


Suppose I want to annotate this function:

def add_one(value):
    match value:
        case int():
            return value + 1
        case str():
            return value + " and one more"
        case _:
            raise TypeError()

I want to tell the type checker "This function can be called with an int (or subclass) or a str (or subclass). In the former case it returns an int, and in the latter a str." How can this be accomplished?

Here are some of my failed attempts:

Attempt 1

def add_one(value: int | str) -> int | str: ...

This is too loose. The type checker no longer knows that the return type matches the argument type; as far as it is concerned, passing an int might return a str.

Attempt 2

def add_one[T: int | str](value: T) -> T: ...

This is incorrect: the function doesn't return exactly the same type as the argument. Passed an IntEnum member it returns an int, and passed a StrEnum member it returns a str, yet this signature promises the enum type back.

Attempt 3

def add_one[T: (int, str)](value: T) -> T: ...

This is better, but now I can't call add_one with Union[int, str].

These questions talk about the differences between bounds and constraints, but I lack the brainpower to use them to solve my problem.
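To make the bound/constraint distinction between attempts 2 and 3 concrete, here is a sketch in the older TypeVar spelling that the [T: ...] syntax replaces (TBound and TConstrained are hypothetical names). A bound is a single upper limit: the checker may solve T to any subtype of it, including a subclass such as an IntEnum or the union int | str itself. Constraints are a closed list: T must be solved to exactly one entry, so a subclass argument is widened to its constraint, and an int | str argument fits none of them.

```python
from typing import TypeVar, Union

# Attempt 2 in TypeVar form: one upper bound.
# A checker may solve TBound to int, str, int | str, or any subclass of them.
TBound = TypeVar("TBound", bound=Union[int, str])

# Attempt 3 in TypeVar form: a closed list of constraints.
# A checker must solve TConstrained to exactly int or exactly str.
TConstrained = TypeVar("TConstrained", int, str)

print(TBound.__bound__)
print(TConstrained.__constraints__)  # (<class 'int'>, <class 'str'>)
```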

Attempt 4

from typing import overload

@overload
def add_one(value: int) -> int: ...

@overload
def add_one(value: str) -> str: ...

def add_one(value: int | str) -> int | str:
    # Put real implementation here
    ...

This is the best I can do. It does the right thing, but requires me to type out the function signature for every possible type it handles. That doesn't seem like much for two types, but my real code already handles 7 or 8, and I intend to add more. It also requires me to manually expand any Union.

Is there some better way to tell the type checker "Here's a Union and a T. Make the T be whatever union arg matched."?


Solution

  • One solution I can think of is to mix attempts 2, 3 and 4:

    (playgrounds: Mypy, Pyright)

    from enum import Enum, IntEnum, StrEnum
    from typing import overload, reveal_type
    
    @overload
    def add_one[T: (str, int, bytes)](value: T) -> T: ...
    
    @overload
    def add_one[T: str | int | bytes](value: T) -> T: ...
    
    def add_one(value: str | int | bytes) -> str | int | bytes: ...
    
    class S(StrEnum):
        A = ''
    
    class I(IntEnum):
        B = 0
        
    class B(bytes, Enum):
        C = b''
    
    def f(si: str | int, sb: str | bytes, ib: int | bytes, sib: str | int | bytes) -> None:
        reveal_type(add_one(''))    # str
        reveal_type(add_one(0))     # int
        reveal_type(add_one(b''))   # bytes
        
        reveal_type(add_one(S.A))   # str
        reveal_type(add_one(I.B))   # int
        reveal_type(add_one(B.C))   # bytes
        
        reveal_type(add_one(si))    # str | int
        reveal_type(add_one(sb))    # str | bytes
        reveal_type(add_one(ib))    # int | bytes
        reveal_type(add_one(sib))   # str | int | bytes
    

    This has one minor problem: the type checker marks the second overload as overlapping with the first. It should be fine to silence that with # type: ignore.

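    Additionally, unions of subtypes like S | I won't be handled correctly: the constrained overload can't match a union, so the bounded overload solves T as S | I and reports a return type of S | I, even though the function actually returns str | int. An upcast (v: str | int = S.A if bool() else I.B) would be necessary to make it work.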