
Mypy type narrowing with recursive generic types


Let's say I make a generic class whose objects only contain one value (of type T).

from typing import Generic, TypeVar

T = TypeVar('T')

class Contains(Generic[T]):
    val: T
    def __init__(self, val: T):
        self.val = val

Note that self.val can itself be a Contains object, so a recursive structure is possible. I want to define a function that will reduce such a structure to a single non-Contains object.

def flatten(x):
    while isinstance(x, Contains):
        x = x.val
    return x
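At runtime the untyped function already does what is intended. The following is a minimal self-contained sketch; the sample values (42, "plain") are illustrative and not from the question.

```python
from typing import Generic, TypeVar

T = TypeVar('T')

class Contains(Generic[T]):
    val: T
    def __init__(self, val: T):
        self.val = val

def flatten(x):
    # Keep unwrapping Contains layers until a non-Contains value remains
    while isinstance(x, Contains):
        x = x.val
    return x

print(flatten(Contains(Contains(Contains(42)))))  # 42
print(flatten("plain"))  # plain: a non-Contains value is returned unchanged
```

The open question is purely about the static type, not the runtime behaviour.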

What should the type signature of 'flatten' be?

I tried to make a recursive Nested type:

Nested = T | Contains['Nested[T]']

but this confuses the type checker, because T could itself be a Contains object.

def flatten(x: Nested[T]) -> T:
    while isinstance(x, Contains):
        reveal_type(x) # reveals Contains[Unknown] | Contains[Nested]
        x = x.val
    reveal_type(x) # reveals object* | Unknown
    return x

Another approach was to make a separate class for the innermost value:

class Base(Generic[T]):
    val: T
    def __init__(self, val: T):
        self.val = val


Nested = Base[T] | Contains['Nested[T]']

def flatten(x: Nested[T]) -> T:
    while isinstance(x, Contains):
        x = x.val
    return x.val

This works, but you have to wrap the innermost value in a Base object every time, which is cumbersome. Furthermore, Base behaves exactly like Contains; it's the same class written twice! I tried to use NewType instead, but a NewType isn't subscriptable.

Is there a nice (or at least not too ugly) way to do this?


Solution

  • The problem is that static type checkers are typically unable to infer dynamic type changes from assignments such as:

    x = x.val
    

    As a workaround, you can make flatten a recursive function instead, which avoids reassigning x:

    from typing import TypeVar, Generic, TypeAlias
    
    T = TypeVar('T')
    
    class Contains(Generic[T]):
        val: T
        def __init__(self, val: T):
            self.val = val
    
    Nested: TypeAlias = T | Contains['Nested[T]']
    
    def flatten(x: Nested[T]) -> T:
        if isinstance(x, Contains):
            return flatten(x.val)
        reveal_type(x) # Type of "x" is "object*"
        return x
    
