
Correct type annotations for generator function that yields slices of the given sequence?


I'm using Python 3.13 and have this function:

def chunk(data, chunk_size: int):
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

I want to give it type annotations to indicate that it can work with bytes, bytearray, or a general collections.abc.Sequence of any kind, and have the return type be a Generator of the exact input type. I do not want the return type to be a union of all possible inputs (e.g. bytes | bytearray | Sequence[T]) because that's overly wide; I want the precise type that I put in to come back out the other end. Calling chunk on a bytes object should return Generator[bytes], and so on.

Since bytes and bytearray both conform to Sequence[T], my first attempt was this:

def chunk[T](data: Sequence[T], chunk_size: int) -> Generator[Sequence[T]]:
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

But this widens the type: each yielded chunk is typed as Sequence[T], not bytes, and pyright complains when I pass a chunk into a function that takes a bytes parameter (def print_bytes(b: bytes) -> None: ...):

error: Argument of type "Sequence[int]" cannot be assigned to parameter "b" of type "bytes" in function "print_bytes"
    "Sequence[int]" is not assignable to "bytes" (reportArgumentType)

So then I tried a bound type variable: "chunk can take any Sequence and returns a Generator of that type."

def chunk[T: Sequence](data: T, chunk_size: int) -> Generator[T]:
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

This time, pyright complains about the function itself:

error: Return type of generator function must be compatible with "Generator[Sequence[Unknown], Any, Any]"
    "Generator[Sequence[Unknown], None, Unknown]" is not assignable to "Generator[T@chunk, None, None]"
      Type parameter "_YieldT_co@Generator" is covariant, but "Sequence[Unknown]" is not a subtype of "T@chunk"
        Type "Sequence[Unknown]" is not assignable to type "T@chunk" (reportReturnType)

I'll admit to not fully understanding the complaint here. We've established via the bound that T is a Sequence, but pyright doesn't like it, and I'm assuming my code is at fault.

Using typing.overload works:

@typing.overload
def chunk[T: bytes | bytearray](data: T, chunk_size: int) -> Generator[T]: ...

@typing.overload
def chunk[T](data: Sequence[T], chunk_size: int) -> Generator[Sequence[T]]: ...

def chunk(data, chunk_size: int):
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

In this case, pyright is able to pick the correct overload for all of my uses, but this feels a little silly: there's twice as much typing code as actual implementation code!

What are the correct type annotations for my chunk function that returns a Generator of the specific type I passed in?


Solution

  • You can define a Protocol that describes what slicing the object returns (the same type as the object itself) and then use that protocol as the bound for your type parameter:

    from collections.abc import Generator, Sized
    from typing import Protocol, Self
    
    
    class Sliceable(Sized, Protocol):
        def __getitem__(self: Self, key: slice, /) -> Self: ...
    
    
    def chunk[T: Sliceable](data: T, chunk_size: int) -> Generator[T]:
        yield from (
            data[i : i + chunk_size]
            for i in range(0, len(data), chunk_size)
        )
    

    The chunk function can be tested using:

    byte_value = b"0123456789"
    
    def print_bytes(b: bytes) -> None: ...
    
    for byte_ch in chunk(byte_value, 10):
        print_bytes(byte_ch)
    
    str_value = "abcdefghijklmnopq"
    
    def print_string(b: str) -> None: ...
    
    for str_ch in chunk(str_value, 10):
        print_string(str_ch)
    
    list_value = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    
    def print_list(b: list) -> None: ...
    
    for list_ch in chunk(list_value, 10):
        print_list(list_ch)
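
The calls above only exercise the static types. As a complementary runtime sanity check, the sketch below restates Sliceable and chunk and confirms that each chunk has the same concrete type as its input. It uses the older TypeVar spelling in place of the `[T: Sliceable]` syntax so that it can also run on interpreters older than 3.12; that substitution and the sample inputs are assumptions made for illustration, not part of the answer itself.

```python
from collections.abc import Generator, Sized
from typing import Protocol, TypeVar

# Older spelling of the bound type variable, assumed here to stand in for
# the `[T: Sliceable]` syntax used in the answer.
T = TypeVar("T", bound="Sliceable")


class Sliceable(Sized, Protocol):
    # Slicing must give back the same type it was applied to.
    def __getitem__(self: T, key: slice, /) -> T: ...


def chunk(data: T, chunk_size: int) -> Generator[T, None, None]:
    yield from (
        data[i : i + chunk_size]
        for i in range(0, len(data), chunk_size)
    )


# Illustrative inputs (not taken from the original answer).
bytes_chunks = list(chunk(b"0123456789", 4))
bytearray_chunks = list(chunk(bytearray(b"0123456789"), 4))
tuple_chunks = list(chunk((0, 1, 2, 3, 4), 2))

# At runtime every chunk keeps the concrete type of its input.
assert all(type(c) is bytes for c in bytes_chunks)
assert all(type(c) is bytearray for c in bytearray_chunks)
assert all(type(c) is tuple for c in tuple_chunks)
```

The type checks use `type(...) is ...` rather than equality because a bytearray compares equal to a bytes object with the same contents, so `==` alone would not show that the type was preserved.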
    

    pyright fiddle

    mypy fiddle
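
As background on why the `T: Sequence` bound from the question is rejected: collections.abc.Sequence only promises that slicing returns some Sequence, not the type that was sliced. The sketch below uses a hypothetical Window class (an illustration, not from the question or the answer) whose slices are plain lists. It is a valid Sequence[int], yet a chunk function bound to Sequence could not honestly claim to yield Window objects for it, which is the gap the Self-returning protocol closes.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import overload


class Window(Sequence[int]):
    """Hypothetical Sequence whose slices are plain lists, not Windows."""

    def __init__(self, items: list[int]) -> None:
        self._items = items

    def __len__(self) -> int:
        return len(self._items)

    @overload
    def __getitem__(self, index: int) -> int: ...
    @overload
    def __getitem__(self, index: slice) -> list[int]: ...
    def __getitem__(self, index: int | slice) -> int | list[int]:
        # Returning a list for a slice is allowed by the Sequence contract,
        # because list[int] is itself a Sequence[int].
        return self._items[index]


window = Window([1, 2, 3, 4])
piece = window[0:2]

# The slice is still a Sequence, but it is no longer a Window.
assert isinstance(piece, Sequence)
assert not isinstance(piece, Window)
```

Because Window's slice does not return Window, it would be expected to fail the Sliceable bound as well, so a checker should reject chunk(window, 2) under the protocol-based signature instead of silently reporting the wrong yield type.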