Using Pydantic 1.10.7 and Python 3.11.2.
I have a recursive Pydantic model and I would like to deserialize each type properly using a discriminated union.
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field

class Base(BaseModel):
    kind: str
    sub_models: Annotated[
        List[Union["A", "B"]],
        Field(
            default_factory=list,
            discriminator="kind",
        ),
    ]

class A(Base):
    kind: Literal["a"]
    a_field: str

class B(Base):
    kind: Literal["b"]
    b_field: str

Base.update_forward_refs()  # resolve the forward references to A and B
I would like to automatically register the subclasses in a way Pydantic will be able to understand, like so:
from typing import Annotated, List, Literal, Set, Type, TypeVar, Union
from pydantic import BaseModel, Field

T = TypeVar("T", bound="Base")

class Base(BaseModel):
    kind: str
    sub_models: Annotated[
        List["Base"],
        Field(
            default_factory=list,
            discriminator="kind",
        ),
    ]
    _subs: Set[Type["Base"]] = set()

    def __init_subclass__(cls, /, **kwargs):
        Base._subs.add(cls)
        cls.__annotations__["kind"] = Literal[cls.__name__.lower()]  # <- works
        # a list comprehension in a type definition is not valid:
        Base.__annotations__["sub_models"] = List[Union[subclass for subclass in Base._subs]]

class A(Base):
    a_field: str

class B(Base):
    b_field: str
Any idea how to configure the discriminated union dynamically?
I have tried manually registering the subclasses, but that involves a circular dependency of type hints, and I have to remember to add each new type to the union.
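For what it's worth, the dynamic Literal construction that the comment above marks as working can be checked in isolation with the standard library alone (class name A here is just for illustration):

```python
from typing import Literal, get_args

class A:
    pass

# Literal accepts a runtime string value, so the "kind" type can be
# derived from the class name inside __init_subclass__:
kind_type = Literal[A.__name__.lower()]
assert get_args(kind_type) == ("a",)
```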
The challenge here is that a lot of the heavy lifting of Pydantic model creation is done by the metaclass, specifically its __new__ method.
Simply re-defining sub_models and kind in the __annotations__ dictionary is not enough to affect the actual fields, because those are created and configured by the metaclass before __init_subclass__ is called.
So if we don't want to go through the trouble of overriding the metaclass, we need some function to set a field on an existing (i.e. fully defined) model.
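A minimal sketch of the problem (the model name M is made up for illustration): narrowing an annotation after class creation has no effect on validation, because the field was already built by the metaclass.

```python
from typing import Literal
from pydantic import BaseModel

class M(BaseModel):
    kind: str

# Mutating __annotations__ after the class exists does not rebuild the field:
M.__annotations__["kind"] = Literal["m"]
m = M(kind="anything")  # still accepted as a plain str
```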
This seems to work:
from __future__ import annotations

from typing import Annotated, Any, ClassVar, Literal, Union

from pydantic import BaseModel, Field
from pydantic.fields import ModelField, Undefined


def set_field_on_model(
    model: type[BaseModel],
    name: str,
    annotation: Any,
    value: Any = Undefined,
) -> None:
    """
    Inspired heavily by what is done inside `ModelMetaclass.__new__`.

    May be incomplete!
    """
    model.__fields__[name] = ModelField.infer(
        name=name,
        value=value,
        annotation=annotation,
        class_validators=None,
        config=model.__config__,
    )
    model.__annotations__[name] = annotation

# more code below...
Now we can use that function inside our __init_subclass__ method to dynamically update the model fields we want:
class Base(BaseModel):
    _subs: ClassVar[set[type["Base"]]] = set()
    sub_models: list["Base"] = Field(default_factory=list)
    kind: str

    def __init_subclass__(cls, /, **kwargs: object) -> None:
        super().__init_subclass__(**kwargs)
        Base._subs.add(cls)
        set_field_on_model(cls, "kind", Literal[cls.__name__.lower()])
        if len(Base._subs) < 2:
            return  # a discriminated union requires at least two distinct types
        sub_models_annotation = list[  # type: ignore[index, misc, valid-type]
            Annotated[
                Union[tuple(Base._subs)],
                Field(discriminator="kind"),
            ]
        ]
        for model in {Base} | Base._subs:
            set_field_on_model(
                model,
                "sub_models",
                sub_models_annotation,
                Field(default_factory=list),
            )
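One detail used above, sketched in isolation with the standard library only: Union can be parameterized with a runtime tuple of types, which is what makes Union[tuple(Base._subs)] possible.

```python
from typing import Union

# At runtime, Union[X, Y] and Union[(X, Y)] are the same object:
members = (int, str)
assert Union[members] == Union[int, str]
```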
A few things worth noting:
- Since Union is special, you can just pass a tuple of types to it, because at runtime Union[X, Y] is equivalent to Union[(X, Y)]. (see this answer)
- To get a discriminated union for the items of the sub_models list, you have to define the discriminated union inside it, i.e. you must pass the Annotated type as the type argument to list, like list[Annotated[Union[...], Field(...)]]. The discriminator parameter relates to the item type of the list, not the list itself.
- The default_factory must be set on the sub_models list field itself, not in the Annotated union of types inside the list. That is why I pass a separate FieldInfo as the value argument to set_field_on_model for the sub_models field.
- We re-define fields on the Base model during subclassing. To keep everything consistent, we must also update the __fields__ of every subclass of Base whenever we create a new one. That is why I loop over Base._subs (and Base itself).
- A discriminated union needs at least two members, which is not the case when we subclass Base for the first time. If we didn't check whether the number of sub-models is at least 2, we would get an error from Pydantic during our first subclass creation.
- Since we are messing with __fields__ after the model is created, all bets are off. This is not really documented, so there are no guarantees that things will not break with the next release. (Not to mention Pydantic v2.)

Let's define (at least) two sub-models and try this out:
from pydantic import ValidationError

# ... import Base

class A(Base):
    a_field: str

class B(Base):
    b_field: str

test_data = {
    "kind": "a",
    "a_field": "foo",
    "sub_models": [
        {"kind": "a", "a_field": "bar"},
        {"kind": "b", "b_field": "baz"},
    ],
}

obj = A.parse_obj(test_data)
print(type(obj.sub_models[0]), type(obj.sub_models[1]))
print(obj.json(indent=4))

# Now introduce an error:
test_data["sub_models"][1]["kind"] = "a"
try:
    A.parse_obj(test_data)
except ValidationError as err:
    print(err.json(indent=4))
Output:
<class '__main__.A'> <class '__main__.B'>
{
    "sub_models": [
        {
            "sub_models": [],
            "kind": "a",
            "a_field": "bar"
        },
        {
            "sub_models": [],
            "kind": "b",
            "b_field": "baz"
        }
    ],
    "kind": "a",
    "a_field": "foo"
}
[
    {
        "loc": [
            "sub_models",
            1,
            "A",
            "a_field"
        ],
        "msg": "field required",
        "type": "value_error.missing"
    }
]
That is the output we want. The two objects in the outer sub_models list are actually instances of distinct classes/models A and B. Their fields are all correctly assigned. And if the kind does not match the expected fields (as in the last example), we get a validation error.