How to type hint a python factory method returning different types?


I am working on a generic framework with the goal of solving different but related problems. A problem consists of data and a bunch of algorithms operating on this data. Data and algorithms may vary from problem to problem, so I need different classes, but they all share a common interface.

I start with a config file defining the problem. At one point in my program, I need a function/method that returns instances of different classes depending on the value (not the type) of a parameter.

The signatures look like this:

from dataclasses import dataclass
from typing import Protocol


# Protocols
class BaseData(Protocol):
    common: int


class BaseAlg[D: BaseData](Protocol):
    def update(self, data: D) -> None: ...


# Implementations data
@dataclass
class Data1:
    common: int
    extra: int


@dataclass
class Data2:
    common: int
    extra: str


# Implementations algorithms
class Alg1:
    def update(self, data: Data1) -> None:
        data.extra += data.common

class Alg2a:
    def update(self, data: Data2) -> None:
        data.extra *= data.common

class Alg2b:
    def update(self, data: Data2) -> None:
        data.extra += "2b"

Now I want a factory that initializes the algorithms and the data (the latter omitted here) for each problem.

class FactoryAlgorithms:

    def _create_1(self) -> list[BaseAlg[Data1]]:
        return [Alg1()]

    def _create_2(self) -> list[BaseAlg[Data2]]:
        return [Alg2a(), Alg2b()]

    def create(self, type_alg: int): # <- How to annotate the return type?
        match type_alg:
            case 1:
                return self._create_1()
            case 2:
                return self._create_2()
            case _:
                raise ValueError(f"Unknown type of data {type_alg}")

How do I annotate the return type of the generic create-method?

mypy accepts list[BaseAlg[Data1]] | list[BaseAlg[Data2]] but

  1. This gets tedious as more and more business logic (algorithms and data structures) is added.
  2. This explicit typing doesn't really reflect what I want to return: A bunch of algorithms, all operating on the same data.

Intuitively, I would write list[BaseAlg[BaseData]], which mypy rejects, I guess for covariance/contravariance reasons:

Incompatible return value type (got "list[BaseAlg[Data1]]", expected "list[BaseAlg[BaseData]]")

Is there a way to tackle this with generics? Or is this design fundamentally flawed?


Solution

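  • Why the simple-looking list[BaseAlg[BaseData]] doesn’t work

    Two separate rules get in the way.

    First, BaseAlg uses D only as the type of an argument to update, so it is contravariant in D. With the PEP 695 syntax from your question the variance is inferred; spelled out with an explicit TypeVar it is:

    from typing import Protocol, TypeVar

    D = TypeVar("D", bound=BaseData, contravariant=True)
    
    class BaseAlg(Protocol[D]):
        def update(self, data: D) -> None: ...
    

    Contravariance reverses the subtype direction: BaseAlg[BaseData] (an algorithm that can update any BaseData) is a subtype of BaseAlg[Data1], not the other way round. Alg1 can only handle Data1, so it is not a BaseAlg[BaseData].

    Second, list is invariant: list[X] is compatible with list[Y] only when X and Y are the same type, no matter how X and Y are related otherwise.

    Result: mypy quite correctly refuses to accept list[BaseAlg[BaseData]] as a super-type of the two concrete lists you return.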
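A small sketch of why this rule protects you. It uses the explicit-TypeVar spelling of BaseAlg shown above, so it also runs before Python 3.12. run_on_any and demo are hypothetical helpers: run_on_any is typed to accept any BaseAlg[BaseData], so if Alg1 counted as one, the checker would also have to accept handing it a Data2, which fails at runtime.

```python
from dataclasses import dataclass
from typing import Protocol, TypeVar


class BaseData(Protocol):
    common: int


D = TypeVar("D", bound=BaseData, contravariant=True)


class BaseAlg(Protocol[D]):
    def update(self, data: D) -> None: ...


@dataclass
class Data1:
    common: int
    extra: int


@dataclass
class Data2:
    common: int
    extra: str


class Alg1:
    def update(self, data: Data1) -> None:
        data.extra += data.common


def run_on_any(alg: BaseAlg[BaseData], data: BaseData) -> None:
    # Type-correct on its own: a BaseAlg[BaseData] promises to accept *any* BaseData.
    alg.update(data)


def demo() -> str:
    # mypy rejects Alg1 here (hence the ignore). If it accepted it, this call
    # would type-check, yet Alg1 would try `str += int` on Data2.extra.
    try:
        run_on_any(Alg1(), Data2(common=1, extra="x"))  # type: ignore[arg-type]
    except TypeError as exc:
        return f"TypeError: {exc}"
    return "no error"
```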

    Two practical ways to type the factory

    1. Overloads with Literal keys (idiomatic for factories)

    from typing import overload, Literal
    
    AlgList1 = list[BaseAlg[Data1]]
    AlgList2 = list[BaseAlg[Data2]]
    
    class FactoryAlgorithms:
    
        def _create_1(self) -> AlgList1:
            return [Alg1()]
    
        def _create_2(self) -> AlgList2:
            return [Alg2a(), Alg2b()]
    
        @overload
        def create(self, type_alg: Literal[1]) -> AlgList1: ...
        @overload
        def create(self, type_alg: Literal[2]) -> AlgList2: ...
        @overload
        def create(self, type_alg: int) -> AlgList1 | AlgList2: ...  # non-literal keys, e.g. read from the config
    
        def create(self, type_alg: int) -> AlgList1 | AlgList2:
            if type_alg == 1:
                return self._create_1()
            if type_alg == 2:
                return self._create_2()
            raise ValueError(f"Unknown type {type_alg}")
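A self-contained, runnable version of option 1, as a sketch: it restates the question's classes using the explicit-TypeVar spelling of BaseAlg (so Python 3.9+ is enough), and run_problem_2 plus its sample values are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Protocol, TypeVar, overload


class BaseData(Protocol):
    common: int


D = TypeVar("D", bound=BaseData, contravariant=True)


class BaseAlg(Protocol[D]):
    def update(self, data: D) -> None: ...


@dataclass
class Data1:
    common: int
    extra: int


@dataclass
class Data2:
    common: int
    extra: str


class Alg1:
    def update(self, data: Data1) -> None:
        data.extra += data.common


class Alg2a:
    def update(self, data: Data2) -> None:
        data.extra *= data.common


class Alg2b:
    def update(self, data: Data2) -> None:
        data.extra += "2b"


AlgList1 = list[BaseAlg[Data1]]
AlgList2 = list[BaseAlg[Data2]]


class FactoryAlgorithms:
    def _create_1(self) -> AlgList1:
        return [Alg1()]

    def _create_2(self) -> AlgList2:
        return [Alg2a(), Alg2b()]

    @overload
    def create(self, type_alg: Literal[1]) -> AlgList1: ...
    @overload
    def create(self, type_alg: Literal[2]) -> AlgList2: ...
    @overload
    def create(self, type_alg: int) -> AlgList1 | AlgList2: ...

    def create(self, type_alg: int) -> AlgList1 | AlgList2:
        # The implementation signature has to cover all overloads above.
        if type_alg == 1:
            return self._create_1()
        if type_alg == 2:
            return self._create_2()
        raise ValueError(f"Unknown type {type_alg}")


def run_problem_2() -> str:
    data = Data2(common=2, extra="ab")
    # The Literal[2] overload is selected, so a checker knows these algorithms
    # take Data2 and would flag passing e.g. a Data1 instead.
    for alg in FactoryAlgorithms().create(2):
        alg.update(data)
    return data.extra
```

One caveat: when the key comes from the config file as a plain int, only the fallback overload matches and you get the union back. Calling update on its elements then requires data that satisfies both Data1 and Data2, so mypy will typically reject it without further narrowing. The overloads pay off mainly where the key is a literal at the call site.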
    

    2. Return a typed Protocol that hides the concrete list

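    If the callers never mutate the collection (they just iterate / call update), return an immutable, covariant view:

    from collections.abc import Iterator
    from typing import Any, Protocol, runtime_checkable
    
    @runtime_checkable
    class AlgGroup(Protocol):
        # BaseAlg[Any], not BaseAlg[BaseData]: BaseAlg is contravariant in D,
        # so a BaseAlg[Data1] is never a BaseAlg[BaseData]
        def __iter__(self) -> Iterator[BaseAlg[Any]]: ...
        # add only the read-only operations you need
    
    class FactoryAlgorithms:
        ...
        def create(self, type_alg: int) -> AlgGroup:
            match type_alg:
                case 1:
                    return tuple(self._create_1())      # tuple is immutable and covariant
                case 2:
                    return tuple(self._create_2())
                case _:
                    raise ValueError(f"Unknown type {type_alg}")
    

    Returning a tuple (or any Sequence) and exposing only the read-only interface sidesteps list invariance. It does not sidestep the contravariance of BaseAlg, though, which is why the element type has to be widened to BaseAlg[Any]: the checker then accepts the result, but it no longer verifies that each algorithm is handed the right kind of data.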
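Point 2 of the question (a bunch of algorithms, all operating on the same data) can also be expressed by not handing out the algorithms on their own. Python's typing has no direct way to say "a list of BaseAlg[D] for some D", but a generic container can fix D when it is built and expose only a non-generic interface. The sketch below assumes callers only need to run the algorithms; Problem, Runnable, create_problem and the sample data values are hypothetical names, not part of either option above.

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, Protocol, TypeVar


class BaseData(Protocol):
    common: int


D = TypeVar("D", bound=BaseData, contravariant=True)  # for the algorithm protocol
T = TypeVar("T", bound=BaseData)  # invariant: pins one data type per problem


class BaseAlg(Protocol[D]):
    def update(self, data: D) -> None: ...


@dataclass
class Data1:
    common: int
    extra: int


@dataclass
class Data2:
    common: int
    extra: str


class Alg1:
    def update(self, data: Data1) -> None:
        data.extra += data.common


class Alg2a:
    def update(self, data: Data2) -> None:
        data.extra *= data.common


class Alg2b:
    def update(self, data: Data2) -> None:
        data.extra += "2b"


@dataclass
class Problem(Generic[T]):
    # Hypothetical container: the checker solves T at construction time and
    # verifies that every algorithm accepts exactly this data type.
    data: T
    algs: Sequence[BaseAlg[T]]

    def run(self) -> None:
        for alg in self.algs:
            alg.update(self.data)


class Runnable(Protocol):
    # What callers of the factory see: no type parameter is exposed.
    def run(self) -> None: ...


def create_problem(type_alg: int) -> Runnable:
    # Hypothetical factory; the sample values stand in for the config file.
    if type_alg == 1:
        return Problem(Data1(common=1, extra=0), [Alg1()])
    if type_alg == 2:
        return Problem(Data2(common=2, extra="ab"), [Alg2a(), Alg2b()])
    raise ValueError(f"Unknown type {type_alg}")
```

Because T is solved per constructor call, combining e.g. a Data1 with [Alg2a()] should be reported by the checker, while callers of create_problem only see run() and cannot mix data and algorithms. Adding a problem type means one more branch rather than one more overload.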