I want my dataclass to have a field that can either be provided manually or, if it isn't, inferred at initialization from the other fields. MWE:
from collections.abc import Sized
from dataclasses import dataclass
from typing import Optional
@dataclass
class Foo:
    """Question MWE: ``index`` defaults to the positions of ``data``.

    Because ``index`` is declared ``Optional[list[int]]``, type checkers still
    report ``None`` as a possible type after ``__post_init__`` has filled it
    in. That is exactly the problem being asked about.
    """

    data: Sized
    index: Optional[list[int]] = None

    def __post_init__(self):
        # Leave a caller-supplied index (including an empty list) untouched.
        if self.index is not None:
            return
        self.index = [*range(len(self.data))]
# ``reveal_type`` is understood by mypy but is not a builtin, so calling it
# unguarded raises NameError when this snippet is executed. mypy treats
# ``TYPE_CHECKING`` as true, so the reveals below still show up in its output.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    reveal_type(Foo.index)  # Union[None, list[int]]
    reveal_type(Foo([1, 2, 3]).index)  # Union[None, list[int]]
How can this be implemented in a way such that:
- it passes mypy type checking, and
- `index` is guaranteed to be of type `list[int]`.

I considered using `field(default_factory=list)`; however, how would one then distinguish the user passing `index=[]` from the sentinel value? Is there a proper solution besides doing
index: list[int] = None # type: ignore[assignment]
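Use `NotImplemented` as the sentinel. It type-checks because typeshed declares the type of `NotImplemented` as a subclass of `Any`, so mypy accepts it as the default of a `list[int]` field: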
from collections.abc import Sized
from dataclasses import dataclass
@dataclass
class Foo:
    """Dataclass whose ``index`` defaults to ``list(range(len(data)))``.

    ``NotImplemented`` serves as the "not provided" sentinel. It is an
    immutable singleton, so it is a legal dataclass default. An explicitly
    passed ``index`` (including ``[]``) is kept as-is.
    """

    data: Sized
    index: list[int] = NotImplemented

    def __post_init__(self) -> None:
        # The ``-> None`` annotation matters: mypy skips the bodies of
        # unannotated functions by default, so without it this method would
        # not be type-checked at all.
        # Compare by identity; using NotImplemented in a boolean context is
        # deprecated.
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))
You can have the `default_factory` return a list with a sentinel object as its only element. You just need to make sure that the sentinel is an instance of `int`; otherwise mypy will complain. Because `__post_init__` checks for the sentinel by identity (`is`) rather than by equality, a user-supplied list such as `[0]` is never mistaken for it, even though the sentinel compares equal to `0`.
from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
    """Dataclass whose ``index`` defaults to ``list(range(len(data)))``.

    The default factory returns a one-element list holding a private
    sentinel. The sentinel subclasses ``int`` so the default type-checks as
    ``list[int]``. ``__post_init__`` recognises it by identity, so a
    caller-supplied list such as ``[]`` or ``[0]`` is never mistaken for it.
    """

    class _IdxSentinel(int):
        """Private ``int`` subtype; only the single instance below is used."""

    _idx_sentinel = _IdxSentinel()

    @staticmethod
    def _idx_sentinel_factory() -> list[int]:
        """Return a fresh ``[sentinel]`` list for each new instance."""
        return [Foo._idx_sentinel]

    data: Sized
    # The lambda is deliberate. A bare ``staticmethod`` object is only
    # callable on Python 3.10+, so passing ``_idx_sentinel_factory`` directly
    # to ``default_factory`` raises TypeError on 3.9. Looking it up through
    # ``Foo`` at call time yields the plain function on every version.
    index: list[int] = field(
        default_factory=lambda: Foo._idx_sentinel_factory()
    )

    def __post_init__(self) -> None:
        # Only the exact sentinel object produced by the factory triggers
        # inference; any list the caller passed is left alone.
        if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
            self.index = list(range(len(self.data)))
I put the entire factory and sentinel logic inside of Foo, but if you don't like that, you can also factor it out:
from collections.abc import Sized
from dataclasses import dataclass, field
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
def _idx_sentinel_factory() -> list[int]:
    """Build a new one-element list containing the module sentinel."""
    fresh: list[int] = []
    fresh.append(_idx_sentinel)
    return fresh
@dataclass
class Foo:
    """Dataclass whose ``index`` defaults to ``list(range(len(data)))``.

    The default comes from the module-level sentinel factory. Any list the
    caller passes (other than one holding the private sentinel) is kept
    unchanged.
    """

    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        current = self.index
        is_default = len(current) == 1 and current[0] is _idx_sentinel
        if not is_default:
            return
        self.index = list(range(len(self.data)))
EDIT: Inspired by the comment from @SUTerliakov, here is a slightly less verbose version that still satisfies type checkers and linters, using a lambda expression instead of a named function:
from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
    """Dataclass whose ``index`` defaults to ``list(range(len(data)))``.

    A lambda default factory yields ``[sentinel]``. ``__post_init__`` detects
    that exact object by identity and swaps in the computed index.
    """

    class _IdxSentinel(int):
        """Private ``int`` subtype used only for the sentinel below."""

    _idx_sentinel = _IdxSentinel()

    data: Sized
    index: list[int] = field(default_factory=lambda: [Foo._idx_sentinel])

    def __post_init__(self) -> None:
        idx = self.index
        # Anything other than the factory's exact [sentinel] was user-supplied.
        if len(idx) != 1 or idx[0] is not type(self)._idx_sentinel:
            return
        self.index = list(range(len(self.data)))
If you love us, you can donate to us via PayPal or buy me a coffee so we can maintain and grow. Thank you!
Donate Us With