Correct type annotations for generator function that yields slices of the given sequence?

I'm using Python 3.13 and have this function:

def chunk(data, chunk_size: int):
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

I want to annotate it so that it accepts bytes, bytearray, or any collections.abc.Sequence, and so that the return type is a Generator of the exact input type. I do not want the yielded type to be a union of all possible inputs (e.g. bytes | bytearray | Sequence[T]), because that is too wide; the precise type I put in should come back out the other end. Calling chunk on a bytes should return Generator[bytes], and so on.

Since bytes and bytearray are both Sequence[int], and therefore match Sequence[T], my first attempt was this:

def chunk[T](data: Sequence[T], chunk_size: int) -> Generator[Sequence[T]]:
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

But this loses the concrete type: the yielded type is Sequence[T], not bytes, and pyright complains when I pass a yielded chunk to a function that takes a bytes parameter (def print_bytes(b: bytes) -> None: ...):

error: Argument of type "Sequence[int]" cannot be assigned to parameter "b" of type "bytes" in function "print_bytes"
    "Sequence[int]" is not assignable to "bytes" (reportArgumentType)

So then I tried putting an upper bound on the type variable: "chunk can take any Sequence and returns a Generator of that type."

def chunk[T: Sequence](data: T, chunk_size: int) -> Generator[T]:
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

This time, pyright complains about the function itself:

error: Return type of generator function must be compatible with "Generator[Sequence[Unknown], Any, Any]"
    "Generator[Sequence[Unknown], None, Unknown]" is not assignable to "Generator[T@chunk, None, None]"
      Type parameter "_YieldT_co@Generator" is covariant, but "Sequence[Unknown]" is not a subtype of "T@chunk"
        Type "Sequence[Unknown]" is not assignable to type "T@chunk" (reportReturnType)

I'll admit to not fully understanding the complaint here. We've established via the bound that T is a Sequence, but pyright doesn't like it, and I'm assuming my code is at fault.

Using typing.overload works:

@typing.overload
def chunk[T: bytes | bytearray](data: T, chunk_size: int) -> Generator[T]: ...

@typing.overload
def chunk[T](data: Sequence[T], chunk_size: int) -> Generator[Sequence[T]]: ...

def chunk(data, chunk_size: int):
    yield from (data[i : i + chunk_size] for i in range(0, len(data), chunk_size))

In this case, pyright is able to pick the correct overload for all of my uses, but this feels a little silly: there is twice as much typing code as actual implementation code!

What are the correct type annotations for my chunk function that returns a Generator of the specific type I passed in?

Charles Nicholson asked Nov 17 '25 07:11


1 Answer

You can define a Protocol whose __getitem__ takes a slice and returns Self (so slicing an object gives back the same type), and then use that protocol as the bound for your type parameter:

from collections.abc import Generator, Sized
from typing import Protocol, Self


class Sliceable(Sized, Protocol):
    def __getitem__(self: Self, key: slice, /) -> Self: ...


def chunk[T: Sliceable](data: T, chunk_size: int) -> Generator[T]:
    yield from (
        data[i : i + chunk_size]
        for i in range(0, len(data), chunk_size)
    )

This can be tested using:

byte_value = b"0123456789"

def print_bytes(b: bytes) -> None: ...

for byte_ch in chunk(byte_value, 10):
    print_bytes(byte_ch)

str_value = "abcdefghijklmnopq"

def print_string(b: str) -> None: ...

for str_ch in chunk(str_value, 10):
    print_string(str_ch)

list_value = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

def print_list(b: list) -> None: ...

for list_ch in chunk(list_value, 10):
    print_list(list_ch)
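
The calls above only exercise the type checker. As a rough runtime sanity check, the sketch below restates the protocol and chunk and confirms that each chunk has the same concrete type as the input. It makes a few assumptions that are not part of the answer: it uses the pre-3.12 TypeVar spelling so that it also runs on older interpreters, the names S and T are illustrative, and Generator[T, None, None] is written out as the long form of Generator[T].

```python
from __future__ import annotations

from collections.abc import Generator, Sized
from typing import Protocol, TypeVar

S = TypeVar("S")  # stands in for typing.Self (illustrative name)


class Sliceable(Sized, Protocol):
    # Slicing must give back the same type that was sliced.
    def __getitem__(self: S, key: slice, /) -> S: ...


T = TypeVar("T", bound=Sliceable)  # pre-3.12 spelling of [T: Sliceable]


def chunk(data: T, chunk_size: int) -> Generator[T, None, None]:
    # Plain loop; behaves the same as the `yield from (...)` form.
    for i in range(0, len(data), chunk_size):
        yield data[i : i + chunk_size]


byte_chunks = list(chunk(b"0123456789", 4))
str_chunks = list(chunk("abcdefghijklmnopq", 10))
list_chunks = list(chunk([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 10))

print(byte_chunks)  # [b'0123', b'4567', b'89']
print(str_chunks)   # ['abcdefghij', 'klmnopq']
print(list_chunks)  # [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11]]
```

Only the last chunk can be shorter than chunk_size, and at runtime every chunk is a bytes, str, or list respectively, matching what the annotations promise.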

pyright fiddle

mypy fiddle
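
As for why the question's T: Sequence bound was rejected: Sequence declares that slicing returns a Sequence[...], not the type of the object being sliced, so a checker cannot assume that data[i : i + chunk_size] is a T. The sketch below uses a hypothetical Countdown class (not from the question) to show that this is a real possibility and not just pyright being strict: it is a valid Sequence, yet slicing it produces a plain tuple.

```python
from collections.abc import Sequence


class Countdown(Sequence):
    """Hypothetical Sequence subclass holding n, n-1, ..., 1."""

    def __init__(self, n: int) -> None:
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, index):
        values = tuple(range(self._n, 0, -1))
        # For a slice this returns a plain tuple, not a Countdown.
        return values[index]


c = Countdown(5)
piece = c[0:2]
print(type(c).__name__, type(piece).__name__)  # Countdown tuple
```

With a bound of Sequence, chunk(Countdown(5), 2) would be typed as yielding Countdown while actually yielding tuple. The Sliceable protocol rules this out by requiring that slicing returns Self.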

MT0 answered Nov 19 '25 21:11