Wrapper stripping the generic parameter of a function erases its type parameter

In the following Python code, I define a generic function wrapper which takes a function of type T → T and replaces it with a function without arguments returning an instance of Delay[T]. This instance simply stores the original function so that it can be called later.

from collections.abc import Callable

class Delay[T]:
    def __init__(self, wrapped: Callable[[T], T]):
        self.wrapped = wrapped

def wrapper[T](wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping
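For reference, a minimal runtime sketch of what this decorator does. It assumes Python 3.9+ and spells the generics with `TypeVar`/`Generic` instead of the 3.12 `[T]` syntax (the runtime behaviour is the same); `shout` and `calls` are hypothetical names used only for illustration. Calling the decorated name only builds a `Delay`; the original function body runs later, when `.wrapped` is called.

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class Delay(Generic[T]):
    """Stores a T -> T function so that it can be called later."""

    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped


def wrapper(wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    # Replace `wrapped` with a zero-argument function that returns a Delay.
    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping


calls: list[str] = []  # records when the original function body actually runs


@wrapper
def shout(arg: str) -> str:  # hypothetical example function
    calls.append(arg)
    return arg.upper()


delayed = shout()  # only builds a Delay; the body of `shout` has not run yet
assert calls == []
result = delayed.wrapped("hi")  # the original function runs here
```

Each call of `shout()` creates a fresh `Delay` around the same stored function.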

When using this wrapper with a normal function, the type checker is happy:

@wrapper
def fun1(arg: str) -> str:
    return arg

reveal_type(fun1) # mypy says: "def () -> Delay[builtins.str]"
reveal_type(fun1()) # mypy says: "Delay[builtins.str]"
reveal_type(fun1().wrapped) # mypy says: "def (builtins.str) -> builtins.str"
reveal_type(fun1().wrapped("test")) # mypy says: "builtins.str"

However, when the wrapped function is generic, the type argument somehow gets erased:

@wrapper
def fun2[T](arg: T) -> T:
    return arg

reveal_type(fun2) # mypy says: "def () -> Delay[Never]"
reveal_type(fun2()) # mypy says: "Delay[Never]"
reveal_type(fun2().wrapped) # mypy says: "def (Never) -> Never"
reveal_type(fun2().wrapped("test")) # mypy says: "Never"

I would have expected the type checker to infer the type of fun2 as def [T] () -> Delay[T], the type of fun2().wrapped as def [T] (T) -> T, and the type of the last line as str.

Note that pyright seems to behave similarly to mypy here.
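The `Never` only exists in the checkers' view of the program. At runtime nothing is erased: every call of `fun2()` wraps the same generic function object, which accepts any argument. A sketch, again written with `TypeVar`/`Generic` so that it also runs before Python 3.12; the separate `U` mirrors the fact that `fun2`'s type parameter belongs to its own scope, distinct from `wrapper`'s `T`:

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")  # fun2's own type parameter, separate from wrapper's T


class Delay(Generic[T]):
    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped


def wrapper(wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping


@wrapper
def fun2(arg: U) -> U:
    return arg


d = fun2()  # statically Delay[Never], as reported above
str_result = d.wrapped("test")  # at runtime this simply returns "test"
int_result = d.wrapped(1)  # and this returns 1
```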

Is there something invalid with the type annotations in my code? Is this a known limitation of the Python type system, or a bug in mypy and pyright?

asked Sep 02 '25 04:09 by matteodelabre


1 Answer

Based on what I think you're trying to do (mypy Playground with a hacky solution), I would say your annotations are invalid: you're trying to use the same symbol T to bind to different type variable scopes.


You already know that fun1: "def () -> Delay[builtins.str]" here ...

@wrapper
def fun1(arg: str) -> str:
    return arg

... but you cannot have fun2: "def () -> Delay[T]" here.

@wrapper
def fun2[T](arg: T) -> T:
    return arg

This is because fun2 is a variable at module scope, and module-scoped variables can't have types containing a free type variable: modules don't bind type variables (only generic classes and generic functions can bind them in their bodies). Something of type Delay[T] at module scope can never be fulfilled, because there is no T at this scope that you could create an instance of.
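For contrast, a type variable is perfectly usable where a generic function binds it. In the hypothetical helper `delay_of` below, the `Delay[T]` in the return type refers to the `T` bound for each individual call, and a checker solves it from the argument at the call site, so the module-level result ends up with a concrete type. A sketch assuming Python 3.9+ (`delay_of` and `double` are illustrative names, not part of the question):

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class Delay(Generic[T]):
    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped


def delay_of(wrapped: Callable[[T], T]) -> Delay[T]:
    # T is bound by this function's signature, so Delay[T] is meaningful here;
    # each call site solves T from the argument that is passed in.
    return Delay(wrapped)


def double(x: int) -> int:
    return x * 2


d_int = delay_of(double)  # T is solved to int, so a checker infers Delay[int]
```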

What you're trying to do might be this:

  • If Delay can be parameterised by a concrete type at the module scope (fun1), then Delay.wrapped must be a callable which receives and returns an argument of this concrete type.
  • If Delay can't be parameterised by a concrete type at the module scope (fun2; T isn't a concrete type), then make Delay.wrapped return the same type it is given.

Delay[Never] is what the checker produces when Delay can't be parameterised by a concrete type at module scope. Hence, a workaround is to introduce a descriptor type for Delay.wrapped whose overloads treat the Delay[Never] case separately.

if TYPE_CHECKING:

    class Wrapped:
        @overload  # type: ignore[no-overload-impl]
        def __get__(self, instance: None, owner: type[object], /) -> Self: ...
        @overload
        def __get__[R](
            self, instance: Delay[Never], owner: type[Delay[Never]], /
        ) -> Callable[[R], R]: 
            """
            Can't be parameterised by a concrete type, return a callable which
            just returns the same type as it receives
            """
        @overload
        def __get__[T](
            self, instance: Delay[T], owner: type[Delay[T]], /
        ) -> Callable[[T], T]: 
            """
            Can be parameterised by a concrete type, return a callable which
            receives and returns this concrete type
            """
        def __set__[T](
            self, instance: Delay[Any], value: Callable[[T], T], /
        ) -> None: ...

# ... Delay (declaring `wrapped = Wrapped()`) and wrapper as in the full solution

@wrapper
def fun1(arg: str) -> str:
    return arg

# `Delay[str]` (parameterised by concrete type `str`)
reveal_type(fun1().wrapped)  # "def (builtins.str) -> builtins.str"

@wrapper
def fun2[T](arg: T) -> T:
    return arg

# `Delay[Never]` (can't fulfil parameterisation)
reveal_type(fun2().wrapped)  # def [R](R) -> R
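`Wrapped` is only ever seen by the type checker, but its overloads model the runtime descriptor protocol: attribute lookup calls `__get__(self, instance, owner)` with `instance` set to `None` when the attribute is read from the class itself (what the `instance: None -> Self` overload describes), and a class that also defines `__set__` is a data descriptor, so assignments on instances are routed through it. A minimal runtime sketch of that protocol, with hypothetical `Stored` and `Holder` classes that are not part of the answer:

```python
from typing import Any


class Stored:
    """Hypothetical data descriptor that keeps its value in the instance __dict__."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name  # attribute name this descriptor was assigned to

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is None:
            # Class access, e.g. Holder.value: return the descriptor itself.
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance: Any, value: Any) -> None:
        # Defining __set__ makes this a data descriptor, so instance
        # assignments come here instead of going straight to __dict__.
        instance.__dict__[self.name] = value


class Holder:
    value = Stored()

    def __init__(self, value: Any) -> None:
        self.value = value  # routed through Stored.__set__


h = Holder(len)
```

Because data descriptors take precedence over the instance `__dict__` on lookup, `h.value` always goes through `Stored.__get__`.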

Full solution below:

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Never, Self, overload


if TYPE_CHECKING:

    class Wrapped:
        @overload  # type: ignore[no-overload-impl]
        def __get__(self, instance: None, owner: type[object], /) -> Self: ...
        @overload
        def __get__[R](
            self, instance: Delay[Never], owner: type[Delay[Never]], /
        ) -> Callable[[R], R]: 
            """
            Can't be parameterised by a concrete type, return a callable which
            just returns the same type as it receives
            """
        @overload
        def __get__[T](
            self, instance: Delay[T], owner: type[Delay[T]], /
        ) -> Callable[[T], T]: 
            """
            Can be parameterised by a concrete type, return a callable which
            receives and returns this concrete type
            """
        def __set__[T](
            self, instance: Delay[Any], value: Callable[[T], T], /
        ) -> None: ...


class Delay[T]:
    if TYPE_CHECKING:
        wrapped = Wrapped()

    def __init__(self, wrapped: Callable[[T], T]):
        self.wrapped = wrapped


def wrapper[T](wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping


@wrapper
def fun1(arg: str) -> str:
    return arg


reveal_type(fun1)  # mypy says: "def () -> Delay[builtins.str]"
reveal_type(fun1())  # mypy says: "Delay[builtins.str]"
reveal_type(fun1().wrapped)  # mypy says: "def (builtins.str) -> builtins.str"
reveal_type(fun1().wrapped("test"))  # mypy says: "builtins.str"
reveal_type(fun1().wrapped(1))  # Error


@wrapper
def fun2[T](arg: T) -> T:
    return arg


reveal_type(fun2)
reveal_type(fun2())
reveal_type(fun2().wrapped)  # def [R](R) -> R
reveal_type(fun2().wrapped("test"))  # str
reveal_type(fun2().wrapped(1))  # int
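Since `Wrapped` and the class-level `wrapped = Wrapped()` sit under `if TYPE_CHECKING:`, and `typing.TYPE_CHECKING` is `False` at runtime, the workaround has no runtime effect: `Delay` gets no class attribute `wrapped`, and instances keep storing the plain function. A sketch checking this, with `Wrapped` reduced to an empty stub (so, unlike the full solution, it is meant to be run rather than type-checked) and a hypothetical `identity` standing in for `fun2`:

```python
from collections.abc import Callable
from typing import TYPE_CHECKING, Generic, TypeVar

T = TypeVar("T")

if TYPE_CHECKING:

    class Wrapped:  # empty stub standing in for the overloaded descriptor
        ...


class Delay(Generic[T]):
    if TYPE_CHECKING:
        wrapped = Wrapped()  # seen by the type checker only; skipped at runtime

    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped  # a plain instance attribute at runtime


def identity(arg):  # hypothetical untyped stand-in for fun2
    return arg


d = Delay(identity)
```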
answered Sep 04 '25 17:09 by dROOOze