
Using __new__ in inherited dataclasses

Suppose I have the following code that is used to handle links between individuals and countries:

from dataclasses import dataclass

@dataclass
class Country:
    iso2 : str
    iso3 : str
    name : str

countries = [ Country('AW','ABW','Aruba'),
              Country('AF','AFG','Afghanistan'),
              Country('AO','AGO','Angola')]
countries_by_iso2 = {c.iso2 : c for c in countries}
countries_by_iso3 = {c.iso3 : c for c in countries}

@dataclass
class CountryLink:
    person_id : int
    country : Country

country_links = [ CountryLink(123, countries_by_iso2['AW']),
                  CountryLink(456, countries_by_iso3['AFG']),
                  CountryLink(789, countries_by_iso2['AO'])]

print(country_links[0].country.name)

This all works fine, but I want to make it less clunky to handle the different forms of input. I also want to use __new__ to check that we get a valid ISO code each time, and I want object creation to fail when the code is invalid. I therefore add a couple of new classes that inherit from CountryLink:

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso2[iso2]
        return new_obj

@dataclass
class CountryLinkFromISO3(CountryLink):
    def __new__(cls, person_id : int, iso3 : str):
        if iso3 not in countries_by_iso3:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso3[iso3]
        return new_obj

country_links = [ CountryLinkFromISO2(123, 'AW'),
                  CountryLinkFromISO3(456, 'AFG'),
                  CountryLinkFromISO2(789, 'AO')]

This appears to work at first glance, but then I run into a problem:

a = CountryLinkFromISO2(123, 'AW')
print(type(a))
print(a.country)
print(type(a.country))

This prints:

<class '__main__.CountryLinkFromISO2'>
AW
<class 'str'>

The inherited object has the right type, but its country attribute is just a string instead of the Country instance that I expect. I added print statements inside __new__ to check the type of new_obj.country, and it is a Country right before the return line.

What I want is for the object a to be an instance of CountryLinkFromISO2 that inherits any changes I make to CountryLink, with a country attribute taken from the dictionary countries_by_iso2. How can I achieve this?

asked Aug 08 '21 18:08 by EdG




2 Answers

Just because the dataclass generates it behind the scenes doesn't mean your classes don't have an __init__(). They do, and it looks like this:

def __init__(self, person_id: int, country: Country):
    self.person_id = person_id
    self.country = country

When you create the class with:

CountryLinkFromISO2(123, 'AW')

the string 'AW' is passed first to __new__() and then to __init__(), which assigns it to self.country and overwrites the Country object that __new__() stored there.
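One way to confirm that the generated __init__() exists and still expects a country argument is to inspect its signature. This is a minimal sketch using the standard inspect module; the class definitions are trimmed copies of the ones in the question, and the variable name params is only for illustration.

```python
import inspect
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

@dataclass
class CountryLink:
    person_id: int
    country: Country

@dataclass
class CountryLinkFromISO2(CountryLink):
    # Custom __new__ with a different second parameter name (iso2).
    def __new__(cls, person_id: int, iso2: str):
        return super().__new__(cls)

# The decorator still generates __init__ from the inherited fields,
# so its parameters are named after the fields, not after __new__'s arguments.
params = list(inspect.signature(CountryLinkFromISO2.__init__).parameters)
print(params)  # ['self', 'person_id', 'country']
```

Both methods receive the same positional arguments, so whatever is passed as iso2 to __new__() also arrives as country in __init__().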

Using __new__() in this way is fragile and returning None from a constructor is fairly un-pythonic (imo). Maybe you would be better off making an actual factory function that returns either None or the class you want. Then you don't need to mess with __new__() at all.

@dataclass
class CountryLinkFromISO2(CountryLink):
    @classmethod
    def from_country_code(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        return cls(person_id, countries_by_iso2[iso2])

a = CountryLinkFromISO2.from_country_code(123, 'AW')

If for some reason it needs to work with __new__(), you could return None from __new__() when there's no match, and convert the code to a Country in __post_init__(), which runs at the end of the generated __init__():

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        return super().__new__(cls)
    
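    def __post_init__(self):
        # __init__ has just stored the raw ISO2 string in self.country;
        # replace it with the matching Country object.
        self.country = countries_by_iso2[self.country]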
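A quick check of the __new__() plus __post_init__() variant, as a self-contained sketch. The lookup table is cut down to a single entry, and the names valid and invalid are only for illustration.

```python
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries_by_iso2 = {'AW': Country('AW', 'ABW', 'Aruba')}

@dataclass
class CountryLink:
    person_id: int
    country: Country

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id: int, iso2: str):
        if iso2 not in countries_by_iso2:
            # Returning a non-instance means __init__ is never called.
            return None
        return super().__new__(cls)

    def __post_init__(self):
        # Swap the raw code stored by __init__ for the Country object.
        self.country = countries_by_iso2[self.country]

valid = CountryLinkFromISO2(123, 'AW')
invalid = CountryLinkFromISO2(999, 'ZZ')
print(type(valid.country).__name__, valid.country.name)  # Country Aruba
print(invalid)                                           # None
```

Callers still have to check for None themselves, which is part of why the factory approach above tends to be easier to reason about.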
answered Oct 30 '22 23:10 by Mark


The behaviour you see is because dataclasses set their fields in __init__, which happens after __new__ has run.
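The ordering can be made visible with a small trace. This is only an illustrative sketch: the Link class and the events list are hypothetical names, not part of the question's code, and __post_init__ is used here simply as a hook that runs at the end of the generated __init__.

```python
from dataclasses import dataclass

events = []  # records what each step sees, in call order

@dataclass
class Link:
    person_id: int
    country: object

    def __new__(cls, person_id, country):
        obj = super().__new__(cls)
        obj.country = "value set in __new__"
        events.append(("__new__", obj.country))
        return obj

    def __post_init__(self):
        # Called by the generated __init__ after it has assigned the fields.
        events.append(("__init__ done", self.country))

link = Link(1, "AW")
for step in events:
    print(step)
# ('__new__', 'value set in __new__')
# ('__init__ done', 'AW')
```

The value assigned in __new__ is gone by the time __init__ finishes, which matches the output shown in the question.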

The Pythonic way to solve this is to provide alternate constructors as class methods. I would drop the subclasses, since they exist only to provide a different constructor.

For example:

@dataclass
class CountryLink:
    person_id: int
    country: Country

    @classmethod
    def from_iso2(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso2[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO2 country code {country_code!r}') from None

    @classmethod
    def from_iso3(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso3[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO3 country code {country_code!r}') from None

country_links = [ CountryLink.from_iso2(123, 'AW'),
                  CountryLink.from_iso3(456, 'AFG'),
                  CountryLink.from_iso2(789, 'AO')]
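With this design an invalid code raises instead of silently producing None. The sketch below shows how a caller might handle that; try_link is a hypothetical helper added for illustration, and the lookup table is reduced to one entry.

```python
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries_by_iso2 = {'AW': Country('AW', 'ABW', 'Aruba')}

@dataclass
class CountryLink:
    person_id: int
    country: Country

    @classmethod
    def from_iso2(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso2[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO2 country code {country_code!r}') from None

def try_link(person_id, code):
    """Hypothetical helper: return the link, or the error message on failure."""
    try:
        return CountryLink.from_iso2(person_id, code)
    except ValueError as exc:
        return str(exc)

ok = try_link(123, 'AW')
bad = try_link(999, 'ZZ')
print(ok.country.name)  # Aruba
print(bad)              # invalid ISO2 country code 'ZZ'
```

Because from_iso2 uses cls rather than CountryLink, any future subclass gets the same constructor and produces instances of its own type.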
answered Oct 30 '22 23:10 by Jasmijn