Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to check overloaded single dispatch with mypy

I have to overload a Python function so that it accepts different numbers of parameters depending on the type of the first argument. I can do that, and I can add a run-time check for incorrect calls (see below). Is it possible to make mypy check for incorrect usage? I am using Python 3.7.3 and mypy 0.770.

from functools import singledispatch

class A:
    """Marker type; passing an instance as the first argument selects the a-variant."""

    def __init__(self) -> None:
        # Placeholder attribute so the class carries some state.
        self.dummy: int = 0

class B:
    """Marker type; passing an instance as the first argument selects the b-variant."""

    def __init__(self) -> None:
        # Placeholder attribute so the class carries some state.
        self.dummy: int = 1


@singledispatch
def base(arg, *args) -> None:
    """Fallback used when no registered variant matches the type of ``arg``.

    Supported calls are meant to be routed to a registered variant, so
    reaching this body indicates an unexpected first-argument type.
    """
    # Neither parameter is used by the fallback.
    del args, arg
    message = "unexpected call to dummy base"
    print(message)

@base.register
def xa(arg: A, val0: int, sval: str = "") -> None:
    """Variant of ``base`` chosen when the first argument is an ``A``.

    Prints the received values, or a complaint when ``sval`` is not a
    string (a run-time guard against mis-shaped calls).
    """
    del arg  # Only needed to select this variant.
    if isinstance(sval, str):
        print("called a-variant: val0={} sval='{}' ".format(val0, sval))
    else:
        print("bad call to a-variant")

@base.register
def xb(arg: B, val0: int, val1: int, sval: str = "") -> None:
    """Variant of ``base`` chosen when the first argument is a ``B``."""
    del arg  # Only needed to select this variant.
    report = "called b-variant: val0={} val1={} sval='{}'".format(val0, val1, sval)
    print(report)

def test() -> None:
    """Exercise ``base`` with calls for both variants, plus one misuse."""
    # Calls shaped for the a-variant, without and with the optional string.
    base(A(), 10)
    base(A(), 11, "xxx")
    # Calls shaped for the b-variant, without and with the optional string.
    base(B(), 10, 20)
    base(B(), 11, 22, "yyy")

    # Bad
    # The int 20 lands in the a-variant's ``sval`` parameter; that variant
    # only notices it at run time through its isinstance check.
    base(A(), 10, 20)

like image 707
uuu777 Avatar asked Dec 07 '25 11:12

uuu777


1 Answer

I was able to solve this problem by using typing.overload.

Update: it took some time to resolve the mypy errors complaining that the overloaded function's implementation does not accept all possible arguments.

from typing import overload, Optional, Union

class A:
    """Marker type accepted as the first argument of the first ``base`` overload."""

    def __init__(self) -> None:
        # Placeholder attribute so the class carries some state.
        self.dummy: int = 0

class B:
    """Marker type accepted as the first argument of the second ``base`` overload."""

    def __init__(self) -> None:
        # Placeholder attribute so the class carries some state.
        self.dummy: int = 1


# Signature seen by the type checker when the first argument is an ``A``.
@overload
def base(arg: A, val0: int, val1: str = "") -> None: ...


# Signature seen by the type checker when the first argument is a ``B``.
@overload
def base(arg: B, val0: int, val1: int, val2: str = "") -> None: ...


def base(
    arg: Union[A, B],
    val0: int,
    val1: Union[str, int] = "",
    val2: Optional[str] = None,
) -> None:
    """Run-time implementation shared by both overloads.

    Its parameter list is wide enough to accept every argument combination
    the overloads allow. Only the type of ``arg`` is inspected; the
    remaining parameters are accepted but not used.
    """
    if isinstance(arg, A):
        outcome = "got A"
    elif isinstance(arg, B):
        outcome = "got B"
    else:
        # Reached only when a caller bypasses the declared overloads.
        outcome = "unexpected"
    print(outcome)


def test() -> None:
    """Call ``base`` with two well-formed and three malformed argument lists."""
    # Each of these matches one of the declared overloads.
    base(A(), 10, "xxx")
    base(B(), 10, 20, "yyy")

    # Bad
    # None of the calls below matches a declared overload, so a type checker
    # is expected to reject them. At run time the first two still execute,
    # while the last one lacks the required ``val0`` argument.
    base(A(), "xxx", 10, 20)
    base(B(), "xxx", 10, "ttt")
    base(25)
like image 167
uuu777 Avatar answered Dec 09 '25 00:12

uuu777



Donate For Us

If you love us, you can donate via PayPal or buy us a coffee so we can maintain and grow the site. Thank you!