I want to type annotate a simple sorting function that receives a list of values that have either `__lt__`, `__gt__`, or both, but not mixed methods (i.e., all values should have the same comparison method), and returns a list containing the same elements it received, in sorted order.
What I've done so far:
```python
from typing import Any, Protocol


class SupportsLT(Protocol):
    def __lt__(self, other: Any, /) -> bool: ...


class SupportsGT(Protocol):
    def __gt__(self, other: Any, /) -> bool: ...


def quick_sort[T: SupportsLT | SupportsGT](seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [*quick_sort(smaller), pivot, *quick_sort(larger_or_equal)]
```
Running `mypy --strict` gives the error `error: Unsupported left operand type for < (some union)  [operator]`. I think this is because there is no guarantee that the values do not have mixed protocols.
Another attempt uses constraints instead of a bound:
```python
def quick_sort[T: (SupportsLT, SupportsGT)](seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [*quick_sort(smaller), pivot, *quick_sort(larger_or_equal)]
```
When I run `mypy --strict`, I get complaints about the type of the list built in the `return` statement:

```
error: List item 0 has incompatible type "list[T]"; expected "SupportsLT"  [list-item]
error: List item 0 has incompatible type "list[T]"; expected "SupportsGT"  [list-item]
error: List item 2 has incompatible type "list[T]"; expected "SupportsLT"  [list-item]
error: List item 2 has incompatible type "list[T]"; expected "SupportsGT"  [list-item]
```
However, the second attempt has a more significant issue. Running the following gives `Revealed type is "builtins.list[SupportsLT]"`, when it should be `"builtins.list[builtins.int]"`:

```python
from typing import reveal_type

x = quick_sort([12, 3])
reveal_type(x)
```
python version: 3.13.0, mypy version: 1.13.0
UPDATE:
The reason I'm using both `SupportsLT` and `SupportsGT` as the bound, while my function only uses the less-than operator (`<`), is that when the left-hand operand's `__lt__` is missing or returns `NotImplemented`, Python calls the `__gt__` method of the right-hand operand and passes the left operand as the argument. Thus, values with only `__gt__` should be considered valid input to my function. Consider the following simple example:
```python
from typing import Self


class HasGT:
    def __init__(self, value: int) -> None:
        self.value = value

    def __gt__(self, other: Self) -> bool:
        return self.value > other.value


print(HasGT(5) < HasGT(6))  # Prints True
```
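The fallback described in the update can be made visible with a small tracing sketch. `TracedGT` and the `calls` list are hypothetical names introduced only for this illustration; the class defines `__gt__` alone and records each call, so the log shows that `a < b` ends up running `b.__gt__(a)` after the inherited `object.__lt__` returns `NotImplemented`.

```python
# Hypothetical tracing helper: records every __gt__ call.
calls: list[str] = []


class TracedGT:
    """Defines only __gt__ and logs each call to it."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __gt__(self, other: "TracedGT") -> bool:
        calls.append(f"{self.value}.__gt__({other.value})")
        return self.value > other.value


a, b = TracedGT(5), TracedGT(6)

# TracedGT inherits object.__lt__, which just returns NotImplemented...
print(object.__lt__(a, b))  # NotImplemented
# ...so `a < b` falls back to the reflected call b.__gt__(a).
print(a < b)  # True
print(calls)  # ['6.__gt__(5)']
```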
Your first attempt is indeed unsafe. Let's see why:
```python
from typing import Self

# quick_sort here is the bounded version from your first attempt


class HasGT:
    def __init__(self, value: int) -> None:
        self.value = value

    def __gt__(self, other: Self) -> bool:
        return self.value > other.value


class HasLT:
    def __init__(self, value: int) -> None:
        self.value = value

    def __lt__(self, other: Self) -> bool:
        return self.value < other.value


foo: list[HasLT | HasGT] = [HasLT(2), HasGT(3)]
sorted_foo = quick_sort(foo)
```
`mypy` accepts this part (the error points at your definition), but it fails at runtime:

```
$ mypy s.py --strict
s.py:17: error: Unsupported left operand type for < (some union)  [operator]
Found 1 error in 1 file (checked 1 source file)
$ python s.py
Traceback (most recent call last):
  File "/tmp/temp/s.py", line 56, in <module>
    sorted_foo = quick_sort(foo)
                 ^^^^^^^^^^^^^^^
  File "/tmp/temp/s.py", line 17, in quick_sort
    if item < pivot:
       ^^^^^^^^^^^^
TypeError: '<' not supported between instances of 'HasGT' and 'HasLT'
```
A `TypeVar` bound to some type `T` can be substituted with any `T1 <= T` (any subtype of `T`, including `T` itself). When `T` is a union type, there's nothing wrong with `T1` being that same union type; it's explicitly allowed. So such an implementation is unsafe.
Your second snippet is actually safe. There's a `mypy` bug that makes it reject your function as-is (something shady happens during unpacking; I'll have a look later), but replacing the last line with

```python
return quick_sort(smaller) + [pivot] + quick_sort(larger_or_equal)
```

fixes things. Such an implementation passes `mypy --strict`, but is barely useful, as you noticed: its return type will be just that, a protocol with a single comparison method.
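For reference, here is a sketch of the constrained version with that one-line fix applied, so it can be run as a whole. The `TypeVar("T", SupportsLT, SupportsGT)` spelling is an assumption made for compatibility with Python versions before 3.12 and is intended to mean the same as `[T: (SupportsLT, SupportsGT)]`; everything else follows the question's second attempt.

```python
from typing import Any, Protocol, TypeVar


class SupportsLT(Protocol):
    def __lt__(self, other: Any, /) -> bool: ...


class SupportsGT(Protocol):
    def __gt__(self, other: Any, /) -> bool: ...


# Constrained TypeVar: solved to exactly SupportsLT or exactly SupportsGT.
T = TypeVar("T", SupportsLT, SupportsGT)


def quick_sort(seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    # List concatenation instead of [*a, pivot, *b] avoids the unpacking issue.
    return quick_sort(smaller) + [pivot] + quick_sort(larger_or_equal)


print(quick_sort([12, 3]))  # [3, 12]
```

It sorts correctly at runtime, but for `quick_sort([12, 3])` mypy still solves `T` to the `SupportsLT` constraint rather than `int`, which is the limitation described above.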
I think it's reasonable to say "my implementation is fine" and provide the best possible signature for callers. It's also fair to assume that the input collection is homogeneous, so let's just write two overloads (and also avoid restricting the input to lists, since the function works for any sequence):
```python
from collections.abc import Sequence
from typing import Any, Protocol, Self, overload

# [snip] Supports{L,G}T and Has{L,G}T definitions here


@overload
def quick_sort_overloaded[T: SupportsLT](seq: Sequence[T]) -> list[T]: ...
@overload
def quick_sort_overloaded[T: SupportsGT](seq: Sequence[T]) -> list[T]: ...
def quick_sort_overloaded[T: SupportsLT | SupportsGT](seq: Sequence[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:  # type: ignore[operator]
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [
        *quick_sort_overloaded(smaller),  # type: ignore[type-var]
        pivot,
        *quick_sort_overloaded(larger_or_equal)  # type: ignore[type-var]
    ]
```
And now:

```python
from typing import reveal_type

reveal_type(quick_sort_overloaded([HasLT(2), HasLT(3)]))  # N: Revealed type is "builtins.list[__main__.HasLT]"
reveal_type(quick_sort_overloaded([HasGT(2), HasGT(3)]))  # N: Revealed type is "builtins.list[__main__.HasGT]"

foo: list[HasLT | HasGT] = [HasLT(2), HasGT(3)]
try:
    quick_sort_overloaded(foo)  # E: Value of type variable "T" of "quick_sort_overloaded" cannot be "HasLT | HasGT"  [type-var]
except TypeError:
    print("`quick_sort_overloaded` failed as warned")
```
Here's a playground to compare all those solutions.