Comparable types with mypy

I'm trying to create a generic class to express that a value has lower and upper bounds, and to enforce those bounds.

from typing import Generic, TypeVar

T = TypeVar("T")

class Bounded(Generic[T]):
    def __init__(self, minValue: T, maxValue: T) -> None:
        assert minValue <= maxValue
        self.__minValue = minValue
        self.__maxValue = maxValue

However, mypy complains that:

error: Unsupported left operand type for <= ("T")

Apparently the typing module doesn't allow me to express this (although it looks like a Comparable type might be added in the future).

I think it would be enough to check that the object has __eq__ and __lt__ methods (for my use case, at least). Is there currently any way to express this requirement in Python so that mypy would understand it?
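For context, mypy's complaint is justified: an unconstrained TypeVar can be instantiated with any type, including one that defines no ordering at all, and then `<=` fails at runtime. A minimal runtime illustration, using a hypothetical `Opaque` class and `compare` helper that are not part of the question:

```python
class Opaque:
    """Hypothetical class that defines no ordering methods of its own."""


def compare(a, b):
    # Deliberately unannotated: this only shows what happens at runtime
    # when T is instantiated with a type that has no ordering.
    return a <= b


try:
    compare(Opaque(), Opaque())
except TypeError as exc:
    # object's default __le__ and the reflected __ge__ both return
    # NotImplemented, so Python raises TypeError instead of guessing.
    print("TypeError:", exc)
```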

asked Dec 25 '17 01:12 by Hannes Karppila


1 Answer

After a bit more research, I found a solution: protocols (PEP 544). Since they are not yet stable as of Python 3.6, they have to be imported from the typing_extensions module.
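Since Python 3.8, `Protocol` is also available directly from the standard `typing` module. One common way to support both older and newer interpreters is a `sys.version_info` check, which mypy understands. This is a sketch that assumes `typing_extensions` is installed on Python 3.7 and earlier; `SupportsLessThan` is a hypothetical name used only for illustration:

```python
import sys
from typing import Any

# mypy evaluates sys.version_info checks and only analyses the branch
# that matches the Python version it is targeting.
if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    # Assumes the typing_extensions package is installed on older Pythons.
    from typing_extensions import Protocol


class SupportsLessThan(Protocol):
    """Hypothetical minimal protocol: any type with a __lt__ method."""

    def __lt__(self, other: Any) -> bool:
        ...
```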

import typing
from typing import Any
from typing_extensions import Protocol
from abc import abstractmethod

C = typing.TypeVar("C", bound="Comparable")

class Comparable(Protocol):
    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        pass

    @abstractmethod
    def __lt__(self: C, other: C) -> bool:
        pass

    def __gt__(self: C, other: C) -> bool:
        return (not self < other) and self != other

    def __le__(self: C, other: C) -> bool:
        return self < other or self == other

    def __ge__(self: C, other: C) -> bool:
        return (not self < other)
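The protocol is structural: a class never has to inherit from `Comparable`, it only needs matching methods. A minimal standalone sketch of a generic helper bounded by the same kind of protocol, assuming Python 3.8+ (so `Protocol` comes from `typing`) and keeping only the two abstract methods; `smallest` and `Money` are hypothetical names for illustration:

```python
from abc import abstractmethod
from typing import Any, Protocol, TypeVar

C = TypeVar("C", bound="Comparable")


class Comparable(Protocol):
    @abstractmethod
    def __eq__(self, other: Any) -> bool: ...

    @abstractmethod
    def __lt__(self: C, other: C) -> bool: ...


def smallest(first: C, *rest: C) -> C:
    """Return the smallest argument, using only the protocol's __lt__."""
    result = first
    for item in rest:
        if item < result:
            result = item
    return result


class Money:
    # Hypothetical class: it does NOT inherit from Comparable,
    # it merely has the right methods (structural typing).
    def __init__(self, cents: int) -> None:
        self.cents = cents

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Money) and self.cents == other.cents

    def __lt__(self, other: "Money") -> bool:
        return self.cents < other.cents


print(smallest(3, 1, 2))  # 1
print(smallest(Money(500), Money(250)).cents)  # 250
```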

Now we can define our type as:

class Bounded(typing.Generic[C]):
    # Reuses the TypeVar C defined above (bound to Comparable).
    def __init__(self, minValue: C, maxValue: C) -> None:
        assert minValue <= maxValue
        self.__minValue = minValue
        self.__maxValue = maxValue

And Mypy is happy:

from functools import total_ordering

@total_ordering
class Test(object):
    def __init__(self, value):
        self.value = value
    def __eq__(self, other):
        return self.value == other.value
    def __lt__(self, other):
        return self.value < other.value

Bounded(Test(1), Test(10))
Bounded(1, 10)
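The question also asks about enforcing the bounds, which the `Bounded` class above only does at construction time. A minimal standalone sketch of how enforcement methods could be written against the same kind of protocol: `contains` and `clamp` are hypothetical method names, the protocol is trimmed to `__lt__` alone (in line with the question's remark that `__eq__` and `__lt__` should be enough), and Python 3.8+ is assumed so `Protocol` comes from `typing`:

```python
from abc import abstractmethod
from typing import Generic, Protocol, TypeVar

C = TypeVar("C", bound="Comparable")


class Comparable(Protocol):
    # Minimal protocol: only __lt__ is required.
    @abstractmethod
    def __lt__(self: C, other: C) -> bool: ...


class Bounded(Generic[C]):
    def __init__(self, min_value: C, max_value: C) -> None:
        # Raise instead of assert: assertions are skipped under "python -O".
        if max_value < min_value:
            raise ValueError("min_value must not be greater than max_value")
        self._min = min_value
        self._max = max_value

    def contains(self, value: C) -> bool:
        """Return True if min <= value <= max, using only __lt__."""
        return not (value < self._min) and not (self._max < value)

    def clamp(self, value: C) -> C:
        """Return the nearest bound if value lies outside [min, max]."""
        if value < self._min:
            return self._min
        if self._max < value:
            return self._max
        return value


bounds = Bounded(1, 10)
print(bounds.contains(5), bounds.contains(11))  # True False
print(bounds.clamp(-3), bounds.clamp(42), bounds.clamp(7))  # 1 10 7
```

Raising `ValueError` rather than using `assert` is a design choice: assertions are removed when Python runs with `-O`, so they should not be relied on to validate input.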
answered Sep 22 '22 06:09 by Hannes Karppila