
How to get Mypy to recognize a class's protocol membership within a Callable?


Mypy properly recognizes a class's adherence to a protocol when the protocol is used as a simple parameter to a type-annotated function. However, when I have a function requiring a callable parameter using that protocol, Mypy misses the user class's protocol membership.

Am I misusing Mypy's protocol pattern, or is this something simply not supported by Mypy at the moment?

(I have seen the thread about Mypy having trouble with Callables that are assigned to a class, so this may be known behavior.)

from typing_extensions import Protocol
from typing import Callable

class P(Protocol):
    def foo(self) -> None: ...


def requires_P(protocol_member: P) -> None:
    protocol_member.foo()

def requires_P_callable(protocol_member: P, function: Callable[[P], None]) -> None:
    function(protocol_member)


class C:
    def foo(self) -> None:
        print("bar")

if __name__ == '__main__':

    c = C()

    def call_foo(c: C) -> None:
        c.foo()

    requires_P(c)
    # mypy is fine with this

    requires_P_callable(c, call_foo)
    # mypy complains:
    #   Argument 2 to "requires_P_callable" has incompatible type "Callable[[C], None]"; expected "Callable[[P], None]"
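To see why mypy rejects the second call, here is a minimal sketch with a hypothetical second class D that also satisfies P, plus a hypothetical C-only method only_on_c() (neither is in the question). The signature Callable[[P], None] lets requires_P_callable pair any P with the callback, so accepting a function that only handles C would be unsound. The call below is forced through with a type: ignore to show what would happen at runtime.

```python
from typing import Callable, Protocol


class P(Protocol):
    def foo(self) -> None: ...


class C:
    def foo(self) -> None:
        print("C.foo")

    # Hypothetical method that exists only on C
    def only_on_c(self) -> None:
        print("C.only_on_c")


class D:
    # Hypothetical class: satisfies P structurally, but has no only_on_c()
    def foo(self) -> None:
        print("D.foo")


def call_c_specific(c: C) -> None:
    # Legitimately relies on something only C provides
    c.only_on_c()


def requires_P_callable(protocol_member: P, function: Callable[[P], None]) -> None:
    # The signature allows ANY P to be passed to the callback
    function(protocol_member)


def demo() -> str:
    try:
        # mypy rejects this for the same reason it rejects call_foo
        requires_P_callable(D(), call_c_specific)  # type: ignore[arg-type]
    except AttributeError as exc:
        return str(exc)
    return "no error"


print(demo())
```

The AttributeError is exactly the failure mypy's check prevents: a Callable[[C], None] cannot stand in for a Callable[[P], None].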




Solution

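  • If you replace the definition of call_foo with:

    def call_foo(c: P) -> None: c.foo()

    the error disappears and the program continues to work. The reason is that Callable is contravariant in its argument types: requires_P_callable asks for a function that can accept any P, and a function declared to accept only C does not make that promise, even though C itself satisfies P. So mypy is behaving as intended rather than missing the protocol membership. The situation is the same with nominal subtyping: if you stop using Protocol and make C a subclass of P, a call_foo declared to take C is still rejected.

    A second workaround is to make requires_P_callable generic, using a TypeVar bound to P so that the object and the callback's parameter are tied to the same type: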

    from typing import Callable, Protocol, TypeVar
    
    # Any type that satisfies P; 'P' is a forward reference because P is defined below
    _TP = TypeVar('_TP', bound='P')
    
    
    class P(Protocol):
        def foo(self) -> None:
            ...
    
    
    class C:
    
        def foo(self) -> None:
            print("foo")
    
    
    def requires_P_callable(prot: _TP, func: Callable[[_TP], None]) -> None:
        # _TP is solved once per call (here to C), so func only has to accept prot's type
        func(prot)
    
    
    def call_foo(c: C) -> None:
        c.foo()
    
    
    if __name__ == '__main__':
    
        c = C()
        requires_P_callable(c, call_foo)
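A sketch of what the generic version buys you: because _TP is solved once per call, the callback may safely use members that exist only on the concrete class. The only_on_c() method and the log list are hypothetical additions for illustration, and bound=P is used directly since P is defined first here.

```python
from typing import Callable, List, Protocol, TypeVar


class P(Protocol):
    def foo(self) -> None: ...


_TP = TypeVar("_TP", bound=P)


class C:
    def __init__(self) -> None:
        self.log: List[str] = []

    def foo(self) -> None:
        self.log.append("foo")

    # Hypothetical method that exists only on C
    def only_on_c(self) -> None:
        self.log.append("only_on_c")


def requires_P_callable(prot: _TP, func: Callable[[_TP], None]) -> None:
    prot.foo()  # allowed: the bound guarantees foo() exists
    func(prot)  # allowed: func accepts exactly prot's type


def call_c_specific(c: C) -> None:
    # Safe: at this call site _TP is C, so c really is a C
    c.only_on_c()


c = C()
requires_P_callable(c, call_c_specific)
print(c.log)  # ['foo', 'only_on_c']
```

Conversely, passing an object of some other P-compatible class together with call_c_specific should still be rejected by mypy, since no single _TP can satisfy both arguments.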