Distinguishing between Pydantic Models with same fields

I'm using Pydantic to define hierarchical data in which there are models with identical attributes.

However, when I save and load these models, Pydantic can no longer distinguish which model was used and picks the first one in the field type annotation.

I understand that this is expected behavior based on the documentation. However, the class type information is important to my application.

What is the recommended way to distinguish between different classes in Pydantic? One hack is to simply add an extraneous field to one of the models, but I'd like to find a more elegant solution.

See the simplified example below: container is initialized with data of type DataB, but after exporting and loading, the new container has data of type DataA, because DataA is the first member of the Union in the type annotation of container.data.

Thanks for your help!

from abc import ABC
from pydantic import BaseModel  # pydantic 1.8.2
from typing import Union

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float

class DataA(Data):
    """ A type of Data"""
    pass

class DataB(Data):
    """ Another type of Data """
    pass

class Container(BaseModel):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]

# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)

# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)

# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataA

twhughes asked Sep 24 '21 23:09


2 Answers

As correctly noted in the comments, without storing additional information, models cannot be distinguished when parsing.
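To see why, here is a simplified, stdlib-only model of left-to-right Union matching. It is a sketch under assumptions, not pydantic's actual implementation: parse_union is a hypothetical helper that accepts the first Union member whose field names match the input keys. Because DataA and DataB have identical fields, the serialized dict carries nothing pointing back to DataB, so DataA wins.

```python
from dataclasses import dataclass, fields
from typing import Union, get_args


@dataclass
class Data:
    number: float


class DataA(Data):
    """Same fields as DataB."""


class DataB(Data):
    """Same fields as DataA."""


def parse_union(union_type, raw: dict):
    """Hypothetical helper: return the first Union member whose field names match raw's keys."""
    for candidate in get_args(union_type):
        if {f.name for f in fields(candidate)} == set(raw):
            return candidate(**raw)
    raise ValueError(f"no Union member matches {raw!r}")


# DataB(number=1.0) serializes to {'number': 1.0}; nothing records the class.
result = parse_union(Union[DataA, DataB], {"number": 1.0})
print(type(result).__name__)  # DataA
```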

As of today (pydantic v1.8.2), the most canonical way to distinguish models when parsing a Union (in case of ambiguity) is to explicitly add a tag field annotated with a Literal type. It will look like this:

from abc import ABC
from pydantic import BaseModel
from typing import Union, Literal

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float


class DataA(Data):
    """ A type of Data"""
    tag: Literal['A'] = 'A'


class DataB(Data):
    """ Another type of Data """
    tag: Literal['B'] = 'B'


class Container(BaseModel):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]


# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)

# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)


# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataB
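The same idea can be sketched without pydantic. In this stdlib-only sketch, parse_union is again a hypothetical helper, and the check against the field's default value merely emulates what Literal validation does: each class carries a tag field with a fixed value, so the DataA candidate is rejected when the tag reads 'B'.

```python
from dataclasses import asdict, dataclass, fields
from typing import Union, get_args


@dataclass
class Data:
    number: float


@dataclass
class DataA(Data):
    tag: str = "A"  # plays the role of tag: Literal['A'] = 'A'


@dataclass
class DataB(Data):
    tag: str = "B"  # plays the role of tag: Literal['B'] = 'B'


def parse_union(union_type, raw: dict):
    """Return the first Union member whose fields match and whose fixed tag agrees."""
    for candidate in get_args(union_type):
        cand_fields = {f.name: f for f in fields(candidate)}
        if set(raw) != set(cand_fields):
            continue
        # Emulate Literal validation: only the declared default tag is accepted.
        if raw["tag"] != cand_fields["tag"].default:
            continue
        return candidate(**raw)
    raise ValueError(f"no Union member matches {raw!r}")


original = DataB(number=1.0)
raw = asdict(original)  # {'number': 1.0, 'tag': 'B'}
restored = parse_union(Union[DataA, DataB], raw)
print(type(restored).__name__)  # DataB
```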

This Literal-tag approach can be automated, but do so at your own risk: it breaks static typing and relies on pydantic internals (ModelField) that may change in future versions:

from abc import ABC
from typing import Literal

from pydantic import BaseModel
from pydantic.fields import ModelField

class Data(BaseModel, ABC):
    """ base class for a Member """
    number: float

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # add a 'tag' field whose only allowed value is the subclass name
        name = 'tag'
        value = cls.__name__
        annotation = Literal[value]

        tag_field = ModelField.infer(
            name=name,
            value=value,
            annotation=annotation,
            class_validators=None,
            config=cls.__config__,
        )
        cls.__fields__[name] = tag_field
        cls.__annotations__[name] = annotation


class DataA(Data):
    """ A type of Data"""
    pass


class DataB(Data):
    """ Another type of Data """
    pass
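For comparison, here is a rough stdlib-only sketch of the same automatic-tagging idea using plain dataclasses instead of pydantic internals. The _registry, to_dict and from_dict names are hypothetical; __init_subclass__ registers each subclass under its class name, which then serves as the tag, mirroring value = cls.__name__ above.

```python
from dataclasses import asdict, dataclass
from typing import ClassVar, Dict, Type


@dataclass
class Data:
    number: float

    # Hypothetical registry mapping a tag (the class name) to its subclass.
    # ClassVar keeps it out of the dataclass fields.
    _registry: ClassVar[Dict[str, Type["Data"]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass registers itself automatically under its own name.
        Data._registry[cls.__name__] = cls

    def to_dict(self) -> dict:
        return {**asdict(self), "tag": type(self).__name__}

    @classmethod
    def from_dict(cls, raw: dict) -> "Data":
        payload = dict(raw)
        subclass = cls._registry[payload.pop("tag")]
        return subclass(**payload)


class DataA(Data):
    pass


class DataB(Data):
    pass


restored = Data.from_dict(DataB(number=1.0).to_dict())
print(type(restored).__name__)  # DataB
```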

alex_noname answered Oct 28 '22 20:10


I just wanted to take the opportunity to list another possible alternative to pydantic here, one that already supports this use case well, as shown below.

I am the creator and maintainer of a relatively new and lesser-known JSON serialization library, the Dataclass Wizard, which relies on the Python dataclasses module to perform its magic. As of the latest version, 0.14.0, dataclass-wizard supports dataclasses within Union types. Previously it did not support them at all, which was a glaring omission and something on my to-do list to (eventually) add.

The reason this did not work before is that the data being de-serialized is usually a JSON object, which only knows simple types such as arrays and dictionaries. A dict would not match any of the types in Union[Data1, Data2], even if it had all the correct dataclass fields as keys, simply because the library doesn't compare the dict against the fields of each dataclass in the Union (though that might change in a future release).

Here is a simple example demonstrating dataclasses in Union types, using class inheritance with the JSONWizard mixin class:

With Class Inheritance
from abc import ABC
from dataclasses import dataclass
from typing import Union

from dataclass_wizard import JSONWizard


@dataclass
class Data(ABC):
    """ base class for a Member """
    number: float


class DataA(Data, JSONWizard):
    """ A type of Data"""

    class _(JSONWizard.Meta):
        """
        This defines a custom tag that uniquely identifies the dataclass.
        """
        tag = 'A'


class DataB(Data, JSONWizard):
    """ Another type of Data """

    class _(JSONWizard.Meta):
        """
        This defines a custom tag that uniquely identifies the dataclass.
        """
        tag = 'B'


@dataclass
class Container(JSONWizard):
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]

Usage is shown below and is again straightforward. De-serialization relies on a special __tag__ key in the dictionary or JSON object; its value is matched against the Meta.tag set up above for each class to pick the correct dataclass.

print('== Load with DataA ==')

input_dict = {
    'data': {
        'number': '1.0',
        '__tag__': 'A'
    }
}

# De-serialize the `dict` object to a `Container` instance.
container = Container.from_dict(input_dict)

print(repr(container))
# prints:
#   Container(data=DataA(number=1.0))

# Show the prettified JSON representation of the instance.
print(container)

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataA

print()

print('== Load with DataB ==')

# initialize container with DataB
data_b = DataB(number=2.0)
container = Container(data=data_b)

print(repr(container))
# prints:
#   Container(data=DataB(number=2.0))

# Show the prettified JSON representation of the instance.
print(container)

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB

# Assert we end up with the same instance when serializing and de-serializing
# our data.
string = container.to_json()
assert container == Container.from_json(string)
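To illustrate how a __tag__ key can select a Union member, here is a minimal stdlib-only sketch. It is not dataclass-wizard's implementation: the TAG class attribute is a hypothetical stand-in for Meta.tag, and dump and load are hypothetical helpers that read the Union members from Container's type hints.

```python
from dataclasses import asdict, dataclass
from typing import Union, get_args, get_type_hints


@dataclass
class Data:
    number: float


class DataA(Data):
    TAG = "A"  # hypothetical stand-in for JSONWizard.Meta.tag (not a field)


class DataB(Data):
    TAG = "B"


@dataclass
class Container:
    data: Union[DataA, DataB]


def dump(container: Container) -> dict:
    """Serialize, adding a '__tag__' key that records the concrete class."""
    return {"data": {**asdict(container.data), "__tag__": type(container.data).TAG}}


def load(raw: dict) -> Container:
    """Pick the member of the annotated Union whose TAG equals '__tag__'."""
    payload = dict(raw["data"])
    tag = payload.pop("__tag__")
    union = get_type_hints(Container)["data"]
    by_tag = {member.TAG: member for member in get_args(union)}
    return Container(data=by_tag[tag](**payload))


container = Container(data=DataB(number=2.0))
raw = dump(container)  # {'data': {'number': 2.0, '__tag__': 'B'}}
assert load(raw) == container
```

Unlike dataclass-wizard, this sketch does no type coercion, so a string such as '1.0' would stay a string.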

Without Class Inheritance

Here is the same example as above, relying solely on dataclasses without any special class inheritance model:

from abc import ABC
from dataclasses import dataclass
from typing import Union

from dataclass_wizard import asdict, fromdict, LoadMeta


@dataclass
class Data(ABC):
    """ base class for a Member """
    number: float


class DataA(Data):
    """ A type of Data"""


class DataB(Data):
    """ Another type of Data """


@dataclass
class Container:
    """ container holds a subclass of Data """
    data: Union[DataA, DataB]


# Setup tags for the dataclasses. This can be passed into either
# `LoadMeta` or `DumpMeta`.
#
# Note that I'm not a fan of this syntax either, so it might change. I was
# thinking of something more explicit, like `LoadMeta(...).bind_to(class)`
LoadMeta(DataA, tag='A')
LoadMeta(DataB, tag='B')

# The rest is the same as before.

# initialize container with DataB
data = DataB(number=2.0)
container = Container(data=data)

print(repr(container))
# prints:
#   Container(data=DataB(number=2.0))

# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB

# Assert we end up with the same data when serializing and de-serializing.
out_dict = asdict(container)
assert container == fromdict(Container, out_dict)
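Tags can also live entirely outside the classes, in the spirit of LoadMeta(DataA, tag='A') above. The sketch below is a stdlib-only assumption, not dataclass-wizard's API: TAGS, CLASSES, encode and decode are hypothetical names, and the json module's default and object_hook hooks add and consume the __tag__ key.

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict, Union


@dataclass
class Data:
    number: float


class DataA(Data):
    pass


class DataB(Data):
    pass


@dataclass
class Container:
    data: Union[DataA, DataB]


# Hypothetical tag table kept outside the classes, so they stay plain dataclasses.
TAGS: Dict[type, str] = {DataA: "A", DataB: "B"}
CLASSES: Dict[str, type] = {tag: cls for cls, tag in TAGS.items()}


def encode(obj):
    """json.dumps default hook: turn tagged dataclasses into dicts with '__tag__'."""
    if type(obj) in TAGS:
        return {**asdict(obj), "__tag__": TAGS[type(obj)]}
    raise TypeError(f"cannot serialize {type(obj).__name__}")


def decode(d: dict):
    """json.loads object_hook: rebuild the tagged class, leave other dicts alone."""
    tag = d.pop("__tag__", None)
    return CLASSES[tag](**d) if tag is not None else d


container = Container(data=DataB(number=2.0))
string = json.dumps({"data": container.data}, default=encode)
restored = Container(**json.loads(string, object_hook=decode))
print(repr(restored))  # Container(data=DataB(number=2.0))
```

The object_hook runs on every decoded JSON object, innermost first, so tagged dicts nested at any depth are rebuilt before their parents.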

rv.kvetch answered Oct 28 '22 19:10