How can I overload the get_data method below so that it returns the correct type based on the value of data_type passed to __init__, instead of a union of both types?
```python
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
```
I was thinking this could be done by making Foo generic, but I'm unsure of the implementation details.
I'd prefer not to pass WoodData/ConcreteData directly as a type argument, because I have many methods whose return types depend on whether the init value is "wood" or "concrete".
To illustrate that last point, I know I could add a generic that takes one of the two return types like so:
```python
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo[MY_RETURN_TYPE: WoodData | ConcreteData]:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> MY_RETURN_TYPE:
        # Type checkers reject both returns below: a WoodData or ConcreteData
        # instance is not necessarily the MY_RETURN_TYPE chosen by the caller.
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
```
But imagine I have tons of methods that conditionally return different types based on the value of data_type. I don't want to add a separate type parameter for each of them; I'd rather overload the methods on the class and have the return types inferred accurately.
Lastly, I know I could split this into two separate subclasses, but it would be nice to keep everything in one class if possible.
OK, for this solution you annotate self in each overload with the parameterized type you want. Both mypy and pyright give the same reveal_type results, both for the base class and for a subclass that stays generic:
```python
from typing import Literal, overload, reveal_type  # reveal_type needs Python 3.11+

class WoodData: ...
class ConcreteData: ...

# T is constrained to one of the two literal types (PEP 695 syntax, Python 3.12+)
class Foo[T: (Literal['wood'], Literal['concrete'])]:
    data_type: T

    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    # Each overload narrows `self`, so the checker picks the return type
    # from the type argument Foo was instantiated with.
    @overload
    def get_data(self: "Foo[Literal['wood']]") -> WoodData:
        ...
    @overload
    def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData:
        ...
    @overload
    def get_data(self) -> WoodData | ConcreteData:
        ...
    def get_data(self):
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

    @overload
    def bar(self: "Foo[Literal['wood']]") -> int:
        ...
    @overload
    def bar(self: "Foo[Literal['concrete']]") -> str:
        ...
    @overload
    def bar(self) -> int | str:
        ...
    def bar(self):
        if self.data_type == "wood":
            return 42
        return "42"

reveal_type(Foo('wood').get_data())      # note: Revealed type is "__main__.WoodData"
reveal_type(Foo('concrete').get_data())  # note: Revealed type is "__main__.ConcreteData"
reveal_type(Foo('wood').bar())           # note: Revealed type is "builtins.int"
reveal_type(Foo('concrete').bar())       # note: Revealed type is "builtins.str"

# works with inheritance too
class Bar[T: (Literal['wood'], Literal['concrete'])](Foo[T]):
    pass

reveal_type(Bar('wood').get_data())      # note: Revealed type is "__main__.WoodData"
reveal_type(Bar('concrete').get_data())  # note: Revealed type is "__main__.ConcreteData"
reveal_type(Bar('wood').bar())           # note: Revealed type is "builtins.int"
reveal_type(Bar('concrete').bar())       # note: Revealed type is "builtins.str"
```
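However, mypy won't type-check the body of the implementation, because the implementation is unannotated and mypy skips unannotated functions unless --check-untyped-defs is enabled. Pyright, on the other hand, seems to report erroneous errors for the body.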