Registering discriminated union automatically

Using Pydantic 1.10.7 and Python 3.11.2.

I have a recursive Pydantic model, and I would like each nested item to be deserialized into its proper type using a discriminated union.

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field

class Base(BaseModel):
    kind: str
    sub_models: Annotated[
        List[Union[A,B]], 
        Field(
            default_factory=list,
            discriminator="kind"
        )
    ]

class A(Base):
    kind: Literal["a"]
    a_field: str

class B(Base):
    kind: Literal["b"]
    b_field: str

I would like to register the subclasses automatically, in a way that Pydantic can understand, like so:

from typing import Annotated, List, Literal, Set, Type, TypeVar, Union

from pydantic import BaseModel, Field

B = TypeVar("B", bound="Base")

class Base(BaseModel):
    kind: str
    sub_models: Annotated[
        List[B], 
        Field(
            default_factory=list,
            discriminator="kind"
        )
    ]
    _subs: Set[Type[B]] = set()

    def __init_subclass__(cls, /, **kwargs):
        Base._subs.add(cls)
        cls.__annotations__["kind"] = Literal[cls.__name__.lower()]  # <- works
        # list comprehension in a type definition is not valid,
        Base.__annotations__["sub_models"] = List[Union[subclass for subclass in Base._subs]]

class A(Base):
    a_field: str

class B(Base):
    b_field: str

Any idea how to have a discriminated union configured dynamically?

I have tried registering the subclasses manually, but that involves a circular dependency between type hints, and I have to remember to add each new type to the union.

abstrus asked Oct 16 '25 02:10


1 Answer

The challenge here is that a lot of the heavy lifting of Pydantic model creation is done by the metaclass, specifically its __new__ method.

Simply re-defining sub_models and kind in the __annotations__ dictionary will not be enough for those changes to affect the actual fields because those are created and configured by the metaclass before __init_subclass__ is called.

So if we don't want to go through the trouble of overriding the metaclass, we need some function to set a field on an existing (i.e. fully defined) model.
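The ordering problem can be reproduced without Pydantic. The sketch below is only an illustration under stated assumptions: FieldCollectingMeta, declared, fields, and events are hypothetical names, and the metaclass is a toy stand-in, not Pydantic's actual ModelMetaclass. It shows that a metaclass which builds its field table before delegating to type.__new__ has already finished that work by the time __init_subclass__ runs.

```python
events = []


class FieldCollectingMeta(type):
    """Toy stand-in for a model metaclass (hypothetical, not Pydantic's code)."""

    def __new__(mcs, name, bases, namespace, **kwargs):
        # Snapshot the declared "fields" before the class object exists.
        namespace["fields"] = dict(namespace.get("declared", {}))
        events.append(f"{name}: fields collected")
        # type.__new__ is the call that triggers __init_subclass__ on the parent.
        return super().__new__(mcs, name, bases, namespace, **kwargs)


class Base(metaclass=FieldCollectingMeta):
    declared = {"kind": str}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        events.append(f"{cls.__name__}: __init_subclass__ called")
        # Too late: this rewrites the declaration, not the snapshot.
        cls.declared = {"kind": "overridden"}


class A(Base):
    declared = {"kind": str, "a_field": str}


print(events)
print(A.fields)    # still the snapshot taken by the metaclass
print(A.declared)  # the late change is visible only here
```

The snapshot in A.fields keeps the original declaration, which mirrors why editing __annotations__ inside __init_subclass__ leaves the real model fields untouched.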


Solution

This seems to work:

from __future__ import annotations
from typing import Annotated, Any, ClassVar, Literal, Union

from pydantic import BaseModel, Field
from pydantic.fields import ModelField, Undefined


def set_field_on_model(
    model: type[BaseModel],
    name: str,
    annotation: Any,
    value: Any = Undefined,
) -> None:
    """
    Inspired heavily by what is done inside `ModelMetaclass.__new__`.
    May be incomplete!
    """
    model.__fields__[name] = ModelField.infer(
        name=name,
        value=value,
        annotation=annotation,
        class_validators=None,
        config=model.__config__,
    )
    model.__annotations__[name] = annotation

# more code below...

Now we can use that function inside our __init_subclass__ method to dynamically update the model fields we want:

class Base(BaseModel):
    _subs: ClassVar[set[type["Base"]]] = set()
    sub_models: list["Base"] = Field(default_factory=list)
    kind: str

    def __init_subclass__(cls, /, **kwargs: object) -> None:
        super().__init_subclass__(**kwargs)
        Base._subs.add(cls)
        set_field_on_model(cls, "kind", Literal[cls.__name__.lower()])
        if len(Base._subs) < 2:
            return  # discriminated union requires at least two distinct types
        sub_models_annotation = list[  # type: ignore[index, misc, valid-type]
            Annotated[
                Union[tuple(Base._subs)],
                Field(discriminator="kind"),
            ]
        ]
        for model in {Base} | Base._subs:
            set_field_on_model(
                model,
                "sub_models",
                sub_models_annotation,
                Field(default_factory=list),
            )

Notes

  • To "dynamically" define a type Union, you can just pass a tuple of types to it because at runtime Union[X, Y] will be equivalent to Union[(X, Y)]. (see this answer)
  • For a proper definition of your sub_models list, you have to define the discriminated union inside it, i.e. you must pass the Annotated type as the type argument to list, like list[Annotated[Union[...], Field(...)]]. The discriminator parameter relates to the item type of the list, not the list itself.
  • Conversely, you must set the default_factory on the sub_models list field itself, not in the Annotated union of types inside the list. That is why I pass a separate FieldInfo as the value argument to set_field_on_model for the sub_models field.
  • It is not enough to just modify the fields of the Base model during subclassing. To keep everything consistent, we must also update the __fields__ of every existing subclass of Base whenever we create a new one. That is why I loop over Base._subs (and Base itself).
  • Since a discriminated union requires at least two types to work, we need an escape hatch for when we subclass Base for the first time. If we didn't check whether the number of sub-models is at least 2, we would get an error from Pydantic during our first subclass creation.
  • This implementation for setting/overriding model fields is based on my moderate but limited understanding of the intricacies of the Pydantic machinery. It may well be incomplete, and there may be additional steps you would have to take to ensure everything remains consistent.
  • In general, when you start fiddling with __fields__ after the model is created, all bets are off. This is not really documented, and there are no guarantees that things will not break with the next release. (Not to mention Pydantic v2.)
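The first two notes can be checked with the standard library alone. This is a minimal sketch, assuming two placeholder classes X and Y and a plain string in place of Pydantic's Field(discriminator="kind"); the names registered and marker are hypothetical. It builds a Union from a runtime collection and confirms that the Annotated metadata ends up on the list's item type rather than on the list.

```python
from typing import Annotated, Union, get_args, get_origin


class X: ...


class Y: ...


registered = [X, Y]  # hypothetical stand-in for Base._subs

# Subscripting Union with a tuple is the same call as Union[X, Y].
dynamic_union = Union[tuple(registered)]
print(dynamic_union == Union[X, Y])  # True

# With a single member the "union" collapses to that type, so there is
# nothing left to discriminate between.
print(Union[tuple([X])] is X)  # True

# Placeholder metadata; in the answer this would be Field(discriminator="kind").
marker = "discriminator=kind"
list_type = list[Annotated[dynamic_union, marker]]

# The metadata is attached to the item type, not to the list itself.
(inner,) = get_args(list_type)
print(get_origin(list_type) is list)  # True
print(inner.__metadata__)  # ('discriminator=kind',)
```

The single-member collapse is one way to see why the len(Base._subs) < 2 guard is needed before a discriminated union is built.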

Demo/Test

Let's define (at least) two sub-models and try this out:

from pydantic import ValidationError

# ... import Base


class A(Base):
    a_field: str


class B(Base):
    b_field: str


test_data = {
    "kind": "a",
    "a_field": "foo",
    "sub_models": [
        {"kind": "a", "a_field": "bar"},
        {"kind": "b", "b_field": "baz"},
    ],
}
obj = A.parse_obj(test_data)
print(type(obj.sub_models[0]), type(obj.sub_models[1]))
print(obj.json(indent=4))

# Now introduce an error:
test_data["sub_models"][1]["kind"] = "a"
try:
    A.parse_obj(test_data)
except ValidationError as err:
    print(err.json(indent=4))

Output:

<class '__main__.A'> <class '__main__.B'>
{
    "sub_models": [
        {
            "sub_models": [],
            "kind": "a",
            "a_field": "bar"
        },
        {
            "sub_models": [],
            "kind": "b",
            "b_field": "baz"
        }
    ],
    "kind": "a",
    "a_field": "foo"
}
[
    {
        "loc": [
            "sub_models",
            1,
            "A",
            "a_field"
        ],
        "msg": "field required",
        "type": "value_error.missing"
    }
]

That is the output we want. The two objects in the outer sub_models list are actually instances of distinct classes/models A and B. Their fields are all correctly assigned. And if the kind does not match the expected fields (as in the last example), we get a validation error.

Daniil Fajnberg answered Oct 17 '25 14:10


