How to define a dataclass so that each of its attributes is a list of the attributes of its nested dataclass?

I have this code:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)

How can I create additional attributes in the dataclass Section so that each one is a list of the corresponding attribute of its nested Position objects?

In my example, I would like the section object to also provide:

sec.name = ['a', 'b', 'c']   #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45]       #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10]     #[pos1.lat,pos2.lat,pos3.lat]

I tried to define the dataclass as:

@dataclass
class Section:
    positions: List[Position]
    names :  List[Position.name]

But this doesn't work, because name is not a class attribute of Position (a dataclass field without a default value only exists on instances). I can define the attributes later in the code (e.g. by doing sec.name = [x.name for x in sec.positions]), but it would be nicer if this could be done at the dataclass definition level.

After posting this question, I found the beginning of an answer (https://stackoverflow.com/a/65222586/13890678).

But I was wondering whether there is a more generic/"automatic" way of defining the Section methods .names(), .lons(), .lats(), ...? That way the developer wouldn't have to define each method individually; instead, the methods would be created from the attributes of the Position objects.

lhoupert asked Dec 09 '20 11:12



3 Answers

You could create a new field after __init__ was called:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        # create _pos after init is done, read only!
        Section._pos = property(Section._get_positions)

    def _get_positions(self):
        # map each Position field name to the list of that field's values
        _pos = {}
        for name in [f.name for f in fields(self.positions[0])]:
            _pos[name] = [getattr(p, name) for p in self.positions]
        return _pos


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])

Out:

[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
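A simpler variant of the same idea, offered as a sketch rather than the answer's exact code: declaring the dictionary as a regular @property on Section avoids reassigning Section._pos every time an instance is created in __post_init__. The property name columns is a hypothetical choice, and the empty-section check is an added assumption.

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    @property
    def columns(self) -> Dict[str, list]:
        # hypothetical name: build {field name: [value from each position]}
        # on every access, so the result always reflects the current positions
        if not self.positions:
            return {}
        return {
            f.name: [getattr(p, f.name) for p in self.positions]
            for f in fields(self.positions[0])
        }


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.columns['name'])  # ['a', 'b', 'c']
print(sec.columns['lon'])   # [52, 46, 45]
```

Because a property has no type annotation, the dataclass machinery does not treat it as a field, so it does not show up in __init__, __repr__ or __eq__.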

Edit:

If you need something more generic, you could override __getattr__:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

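    def __getattr__(self, keyName):
        # only called when normal attribute lookup fails (e.g. for "names")
        if keyName == "positions":
            # guard against infinite recursion if positions is not set yet
            raise AttributeError(keyName)
        for f in fields(self.positions[0]):
            if f"{f.name}s" == keyName:
                return [getattr(x, f.name) for x in self.positions]
        # raise instead of silently returning None, so typos like sec.nmes fail loudly
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {keyName!r}")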

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.names)
print(sec.lons)
print(sec.lats)

Out:

['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
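Both versions read the field names from self.positions[0], so an empty Section raises IndexError. A sketch of one workaround, assuming every Section only ever holds Position objects: take the names from fields(Position) instead, so an empty section simply yields empty lists.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    def __getattr__(self, key):
        # only runs when normal lookup fails, e.g. for "names"
        if key == "positions":
            # avoid infinite recursion if positions has not been set yet
            raise AttributeError(key)
        # assumption: all items are Position, so use the class, not positions[0]
        for f in fields(Position):
            if f"{f.name}s" == key:
                return [getattr(p, f.name) for p in self.positions]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")


empty = Section([])
print(empty.names)  # []
sec = Section([Position('a', 52, 10), Position('b', 46, -10)])
print(sec.lats)     # [10, -10]
```

The trade-off is that Section is now tied to Position; if a section can hold different dataclasses, reading the fields from an element (as above) is the more flexible choice.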
Maurice Meyer answered Oct 19 '22 23:10


After some more thinking, I came up with an alternative solution using methods:


from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

    def names(self):
        return [x.name for x in self.positions]

    def lons(self):
        return [x.lon for x in self.positions]

    def lats(self):
        return [x.lat for x in self.positions]


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec.names())
print(sec.lons())
print(sec.lats())

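One way to get these accessors generated automatically, as a sketch: after both classes are defined, loop over dataclasses.fields(Position) and attach one read-only property per field to Section. This assumes Section only holds Position objects and that the accessor names are formed by appending "s" to each field name (names, lons, lats); the helper _make_accessor is a hypothetical name. Unlike the methods above, the properties are accessed without parentheses.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


def _make_accessor(field_name):
    # hypothetical helper: a factory so each property captures its own field name
    def getter(self):
        return [getattr(p, field_name) for p in self.positions]
    return property(getter)


# generate names, lons, lats, ... from the fields of Position
for f in fields(Position):
    setattr(Section, f"{f.name}s", _make_accessor(f.name))


sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names)  # ['a', 'b', 'c']
print(sec.lons)   # [52, 46, 45]
print(sec.lats)   # [10, -10, -10]
```

The factory function matters: defining the getter directly inside the loop would make every property see the last value of f (late binding), so all of them would return the lat values.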

lhoupert answered Oct 20 '22 01:10


The way I understood you, you'd like to declare dataclasses that are flat data containers (like Position), which are nested into a container of another dataclass (like Section). The outer dataclass should then be able to access a list of all the attributes of its inner dataclass(es) through simple name access.

We can implement this kind of functionality (calling it, for example, introspect) on top of how a regular dataclass works, and can enable it on demand, similar to the already existing flags:

from dataclasses import is_dataclass, fields, dataclass as dc
from typing import get_args, get_origin

# existing dataclass signature, plus an "introspect" keyword
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, introspect=False):

    def wrap(cls):
        # run the original dataclass decorator and keep its return value
        cls = dc(cls, init=init, repr=repr, eq=eq, order=order,
                 unsafe_hash=unsafe_hash, frozen=frozen)

        # add our custom "introspect" logic on top
        if introspect:
            for field in fields(cls):
                # only consider dataclasses nested in list/set/tuple containers;
                # get_origin() handles both List[...] and list[...] annotations
                if get_origin(field.type) not in (list, set, tuple):
                    continue
                args = get_args(field.type)
                if not args or not is_dataclass(args[0]):
                    continue
                contained_dc = args[0]
                # once found, add their fields as properties
                for dc_field in fields(contained_dc):
                    # if there are name clashes, use f"{field.name}_{dc_field.name}" instead
                    property_name = dc_field.name
                    # bind variables explicitly to avoid late-binding bugs
                    def magic_property(self, field=field, dc_field=dc_field):
                        return [getattr(attr, dc_field.name) for attr in getattr(self, field.name)]
                    # here is where the magic happens
                    setattr(cls, property_name, property(magic_property))
        return cls

    # handle being called with or without parentheses
    if _cls is None:
        return wrap
    return wrap(_cls)

The resulting dataclass-function can now be used in the following way:

from typing import List

# regular dataclass
@dataclass
class Position:
    name: str
    lon: float
    lat: float
    
# this one will introspect its fields and try to add magic properties
@dataclass(introspect=True)
class Section:
    positions: List[Position]

And that's it. The properties get added during class construction, and will even update accordingly if any of the objects changes during its lifetime:

>>> p_1 = Position("1", 1.0, 1.1)
>>> p_2 = Position("2", 2.0, 2.1)
>>> p_3 = Position("3", 3.0, 3.1)
>>> section = Section([p_1 , p_2, p_3])
>>> section.name
['1', '2', '3']
>>> section.lon
[1.0, 2.0, 3.0]
>>> p_1.lon = 5.0
>>> section.lon
[5.0, 2.0, 3.0]
Arne answered Oct 19 '22 23:10