How to type hint function to be compatible with numpy

Source code of example.py:

from typing import Union, Any
import numpy as np

Number = Union[int, float, np.floating[Any]]


def add_one(num: Number) -> Number:
    return num + 1


inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs]

avg = np.mean(outputs)

Run mypy:

mypy example.py
src/example.py:14: error: Argument 1 to "mean" has incompatible type "List[Union[float, floating[Any]]]"; expected "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"
Found 1 error in 1 file (checked 1 source file)

I can change everything to np.floating[Any] which fixes the numpy issue, but then I have to cast primitives to np.float32(...):

from typing import Any
import numpy as np


def add_one(num: np.floating[Any]) -> np.floating[Any]:
    return num + 1


inputs = [1, 2, 3]
outputs = [add_one(np.float32(n)) for n in inputs]

avg = np.mean(outputs)

Is there a correct way to type hint the add_one function so that its outputs are compatible with NumPy functions such as np.mean, without breaking compatibility with the Python primitive types? The end goal is to be able to use it like this:

inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs]
avg = np.mean(outputs)
kym asked Mar 01 '26 09:03


1 Answer

MyPy is being too strict about a complex type hierarchy, possibly because of an ill-defined type for the first argument of numpy.mean (maybe a bug in the stubs/type definitions).

The code runs fine; this is only a type-checker error, and it is hard to tell which side is to blame.

You have a few options:

1. Just ignore type checking for the call to mean

avg = np.mean(outputs) # type: ignore

A List[T] should be array_like, since NumPy can handle it at runtime, but for some reason MyPy sometimes fails to make that equivalence for custom element types. The code is correct, so you can simply suppress the check for that one call.
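If you take this route, it may be worth scoping the suppression to one error code instead of silencing everything on the line. A minimal sketch, assuming MyPy reports the problem under its arg-type code (as in the full error message quoted in the explanation further down); it reuses the Number alias and add_one from the question:

```python
from typing import Any, List, Union

import numpy as np

Number = Union[int, float, np.floating[Any]]


def add_one(num: Number) -> Number:
    return num + 1


outputs: List[Number] = [add_one(n) for n in [1, 2, 3]]

# Suppress only the argument-type complaint for this call;
# any other kind of error on the same line is still reported.
avg = np.mean(outputs)  # type: ignore[arg-type]

print("Average:", avg)
```

The narrower comment keeps the rest of the line checked, which matters if the call later gains more arguments.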


2. Keep the type hint as-is and use np.array

Although MyPy fails to recognize that List[T] is array_like for complex and abstract types T, it won't complain about an np.array.

The following two variations of the script each run fine and give no errors or warnings from MyPy. The first converts inside add_one; the second wraps the list of results in np.array:

from typing import Union
import numpy as np

Number = Union[int, float, np.ndarray]

def add_one(num: Number) -> np.ndarray:
    return np.array(num) + 1  # Using np.array


inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs] 

avg = np.mean(outputs)

print("Outputs:", outputs) # Outputs: [2, 3, 4]
print("Average:", avg) # Average: 3.0

Running: mypy ./example.py

Output: Success: no issues found in 1 source file

from typing import Union, Any
import numpy as np

Number = Union[int, float, np.floating[Any]]

def add_one(num: Number) -> Number:
    return num + 1


inputs = [1, 2, 3]
outputs = np.array([add_one(n) for n in inputs])  # Using np.array

avg = np.mean(outputs)

print("Outputs:", outputs) # Outputs: [2 3 4]
print("Average:", avg) # Average: 3.0

Running: mypy ./example.py

Output: Success: no issues found in 1 source file

Note that all of NumPy's documentation examples for mean pass an np.array, not a List, as the argument.


3. Define a custom Protocol that supports arithmetic and numerical operations, plus conversion to primitive types and NumPy types

It'd be a good idea to ensure it behaves well with other number classes. I've used the numbers module's Number class (imported as Real), as it is very broadly compatible, but you could constrain it further.


import numpy as np
from typing import Protocol, TypeVar, Union
from numbers import Number as Real

class Number(Protocol):
    # Arithmetic operations
    def __add__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __radd__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __sub__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rsub__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __mul__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rmul__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __truediv__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rtruediv__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __floordiv__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rfloordiv__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __mod__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rmod__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __pow__(self, other: Union[Real, "Number"]) -> "Number": ...
    def __rpow__(self, other: Union[Real, "Number"]) -> "Number": ...

    # Unary operations
    def __neg__(self) -> "Number": ...
    def __pos__(self) -> "Number": ...
    def __abs__(self) -> "Number": ...

    # Other comparisons
    def __eq__(self, other: Union[Real, "Number"]) -> bool: ...
    def __ne__(self, other: Union[Real, "Number"]) -> bool: ...
    def __lt__(self, other: Union[Real, "Number"]) -> bool: ...
    def __le__(self, other: Union[Real, "Number"]) -> bool: ...
    def __gt__(self, other: Union[Real, "Number"]) -> bool: ...
    def __ge__(self, other: Union[Real, "Number"]) -> bool: ...

    # Conversions
    def __float__(self) -> float: ...
    def __int__(self) -> int: ...
    def __index__(self) -> int: ...

    # Implicit conversion to NumPy types
    def __array__(self) -> np.ndarray: ...
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: ...

N = TypeVar("N", bound=Number)

Then you should be able to use it like this:

def add_one(num: N) -> N: 
    return num + 1 

inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs] 

avg = np.mean(outputs)

print("Outputs:", outputs) # Outputs: [2, 3, 4]
print("Average:", avg) # Average: 3.0

Running: mypy ./example.py

Output: Success: no issues found in 1 source file
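
On the numbers import used above: numbers.Number is the broadest class in that module and also admits complex values, while numbers.Real excludes them, so swapping in the real Real is one way to constrain things further. The sketch below only shows the runtime isinstance behaviour of those two classes; the sample values are made up, and it says nothing about how MyPy treats them, since static checkers have limited support for the numbers classes as far as I know.

```python
from numbers import Number, Real

# Hypothetical sample values, chosen only to show the difference.
samples = [1, 1.5, True, 1 + 2j, "1"]

# Map each value's repr to (is a Number?, is a Real?).
results = {
    repr(value): (isinstance(value, Number), isinstance(value, Real))
    for value in samples
}

for name, (is_number, is_real) in results.items():
    print(f"{name}: Number={is_number}, Real={is_real}")
```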


4. Define a custom implementation of List[Number] that is explicitly array_like

Alternatively, if you're running Python >= 3.12 and MyPy >= 1.12, which offers support for PEP 695, you could define a list subtype that implements the array interface:

...
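class NArray[T: Number](list[T]):  # PEP 695 spelling of the bound; no separate TypeVar needed
    def __array__(self, dtype=None, copy=None):
        # Build from a plain list so np.array does not call __array__ again
        return np.array(list(self), dtype=dtype)

...

outputs: NArray[Number] = NArray([add_one(n) for n in inputs])  # Using NArray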

Running: mypy ./example.py

Output: Success: no issues found in 1 source file
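
For interpreters older than 3.12, the same idea can be spelled with a classic TypeVar. This is a runtime sketch only: I have not checked it against MyPy, the bound on T is left out for brevity, and the dtype and copy parameters are there on the assumption that newer NumPy versions may pass them to __array__.

```python
from typing import List, TypeVar

import numpy as np

T = TypeVar("T")


class NArray(List[T]):
    """List subclass that exposes the NumPy array interface (illustrative)."""

    def __array__(self, dtype=None, copy=None):
        # Convert to a plain list first: calling np.array(self) directly
        # would look up __array__ on this subclass again and recurse.
        return np.array(list(self), dtype=dtype)


outputs = NArray([2, 3, 4])
avg = np.mean(outputs)

print("Average:", avg)
```

The object is still a normal list for everything else (indexing, append, iteration), so existing code that consumes outputs keeps working.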


Explanation

The problem is MyPy's handling of List[Number], not the underlying types contained in it.

As per NumPy's reference, numpy.mean is defined as numpy.mean(a, axis=None, dtype=None, out=None, keepdims=<no value>, *, where=<no value>), where a is of type array_like.

In its glossary, NumPy defines array_like as:

array_like: Any scalar or sequence that can be interpreted as an ndarray. In addition to ndarrays and scalars this category includes lists (possibly nested and with different element types) and tuples. Any argument accepted by numpy.array is array_like.
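
At runtime, that definition is easy to see in action. A small sketch with made-up inputs, covering a scalar, a flat list, a nested list, a tuple, and a list mixing Python and NumPy scalars:

```python
import numpy as np

# Each value below is "array_like" in the glossary's sense.
means = {
    "scalar": np.mean(3),
    "flat list": np.mean([1, 2, 3]),
    "nested list": np.mean([[1, 2], [3, 4]]),
    "tuple": np.mean((1, 2.5)),
    "mixed scalars": np.mean([1, 2.5, np.float32(4)]),
}

for label, value in means.items():
    print(f"{label}: {value}")
```

All five calls succeed, which is why the MyPy error is about the declared types and not about what NumPy can actually handle.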

When you look at the definitions of the overloaded function mean in numpy.core.fromnumeric.pyi, you find that there are 6 overloads:


@overload
def mean(
    a: _ArrayLikeFloat_co,
    axis: None = ...,
    dtype: None = ...,
    out: None = ...,
    keepdims: Literal[False] = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> floating[Any]: ...
@overload
def mean(
    a: _ArrayLikeComplex_co,
    axis: None = ...,
    dtype: None = ...,
    out: None = ...,
    keepdims: Literal[False] = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> complexfloating[Any, Any]: ...
@overload
def mean(
    a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
    axis: None | _ShapeLike = ...,
    dtype: None = ...,
    out: None = ...,
    keepdims: bool = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> Any: ...
@overload
def mean(
    a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
    axis: None = ...,
    dtype: _DTypeLike[_SCT] = ...,
    out: None = ...,
    keepdims: Literal[False] = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> _SCT: ...
@overload
def mean(
    a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
    axis: None | _ShapeLike = ...,
    dtype: DTypeLike = ...,
    out: None = ...,
    keepdims: bool = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> Any: ...
@overload
def mean(
    a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
    axis: None | _ShapeLike = ...,
    dtype: DTypeLike = ...,
    out: _ArrayType = ...,
    keepdims: bool = ...,
    *,
    where: _ArrayLikeBool_co = ...,
) -> _ArrayType: ...

So... what gives? List[Number] should be compatible with array_like, right?

Well, let's look closely at the error MyPy gives:

error: Argument 1 to "mean" has incompatible type "list[int | float | floating[Any]]"; expected "_SupportsArray[dtype[bool_ | integer[Any] | floating[Any]]] | _NestedSequence[_SupportsArray[dtype[bool_ | integer[Any] | floating[Any]]]] | bool | int | float | _NestedSequence[bool | int | float]"  [arg-type]

You'll notice that, although it clearly references a Union of obscure abstract types (possibly auto-generated ones as well), the primary type in the Union is _SupportsArray, a fairly abstract protocol defined in numpy._typing._array_like:

# The `_SupportsArray` protocol only cares about the default dtype
# (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
# array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
    def __array__(self) -> ndarray[Any, _DType_co]: ...
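
To see what a protocol like this asks for, here is a standard-library-only sketch. HasArray and Wrapped are hypothetical names invented for the illustration, not NumPy's classes, and Wrapped returns a plain list so the example does not need NumPy installed.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasArray(Protocol):
    """Hypothetical stand-in for a protocol that requires __array__."""

    def __array__(self) -> Any: ...


class Wrapped:
    """Toy container that offers an __array__ method."""

    def __init__(self, values):
        self.values = list(values)

    def __array__(self) -> Any:
        return self.values


# Structural check: only objects with an __array__ attribute match.
wrapped_matches = isinstance(Wrapped([1, 2, 3]), HasArray)
list_matches = isinstance([1, 2, 3], HasArray)
float_matches = isinstance(1.5, HasArray)

print(wrapped_matches, list_matches, float_matches)
```

A plain list or a Python float has no __array__ method, so neither satisfies such a protocol on its own. That is why the expected type in the error message also lists the _NestedSequence[...] and bool | int | float alternatives.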

The target we're looking for should be:

_SupportsArray[dtype[bool_ | integer[Any] | floating[Any]]]

That's the problem! The type checker expects a _SupportsArray of bool_, integer[Any], or floating[Any].

If Number were Union[bool_, integer[Any], floating[Any]], MyPy wouldn't complain.

But that looks like a bug in NumPy's current type definitions: the expected type should be _SupportsArray[Any], as the documentation states.

HernanATN answered Mar 04 '26 09:03