Better way to iterate over python dataclass keys and values?

I have two dataclasses, Route and Factors. Route contains a value and three copies of Factors.

Route does not know how many variables Factors contains. I want to get the name of these variables, and then get the respective value of each one, for each copy of Factors.

Here is what I currently have:

from dataclasses import dataclass

import tabulate


@dataclass
class Factors:
    do: bool  # does it do the route
    hub: int # how many of the locations are hubs

    def __init__(self, do_init):
        self.do = do_init
        self.hub = 0 # will add later 

    def __str__(self):
        return "%s" % self.do


@dataclass
class Route:
    route: tuple
    skyteam: Factors
    star: Factors
    oneworld: Factors

    def __init__(self, route):
        self.route = route.get('route')
        # this could probably be done with one line loop and a variable with names
        self.skyteam = Factors(route.get('skyteam'))
        self.star = Factors(route.get('star'))
        self.oneworld = Factors(route.get('oneworld'))

    def __str__(self):
        table = [[self.route, "SkyTeam", "StarAlliance", "OneWorld"]] # var name is fine
        for var in Factors.__dict__.get('__annotations__').keys():  # for each factor
            factor = [var]
            factor.append(self.skyteam.__dict__.get(var))
            factor.append(self.star.__dict__.get(var))
            factor.append(self.oneworld.__dict__.get(var))
            table.append(factor)
        return tabulate.tabulate(table, tablefmt='plain')

Input is

{'route': ('BOS', 'DXB'), 'skyteam': True, 'star': True, 'oneworld': True}

Current output is

('BOS', 'DXB')  SkyTeam  StarAlliance  OneWorld
do              True     True          True
hub             0        0             0

Maybe I could search Route for each variable that contains a Factors datatype and iterate over those?
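The idea in the last sentence can be tried directly: instead of reading annotations, check each field's current value with isinstance. This is a minimal sketch. It assumes the generated dataclass __init__ is used (no custom one), and factor_table is a hypothetical helper name. Because it never inspects field.type, it should also keep working if annotations are strings (for example under from __future__ import annotations):

```python
from dataclasses import dataclass, fields


@dataclass
class Factors:
    do: bool
    hub: int = 0


@dataclass
class Route:
    route: tuple
    skyteam: Factors
    star: Factors
    oneworld: Factors


def factor_table(route: Route) -> list:
    """Hypothetical helper: rows of [factor name, value for each alliance]."""
    # Pick out every field of Route whose current value is a Factors instance.
    groups = {
        f.name: getattr(route, f.name)
        for f in fields(route)
        if isinstance(getattr(route, f.name), Factors)
    }
    rows = [[route.route, *groups]]  # header: the route, then the field names
    for factor_field in fields(Factors):
        rows.append(
            [factor_field.name, *(getattr(g, factor_field.name) for g in groups.values())]
        )
    return rows


r = Route(("BOS", "DXB"), Factors(True), Factors(True), Factors(True))
for row in factor_table(r):
    print(row)
```

The rows can then be passed to tabulate.tabulate as in the original __str__.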

asked Mar 16 '20 21:03 by Adin D


3 Answers

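You can use dataclasses.fields(), which accepts either a dataclass type or an instance. Read the values from an instance, though: getattr on the class only finds defaults.

from dataclasses import fields

# your_instance is an instance of your dataclass, e.g. a Route
for field in fields(your_instance):
    print(field.name, getattr(your_instance, field.name))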
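The loop above can be checked with a cut-down version of the question's classes. This is a sketch, and the two-field Route is a simplification for illustration. It also shows dataclasses.asdict(), which walks nested dataclasses and returns plain dicts. That may be handy when every Factors copy should come out as a dict of name/value pairs:

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class Factors:
    do: bool
    hub: int = 0


@dataclass
class Route:  # simplified: only one Factors field
    route: tuple
    skyteam: Factors


r = Route(("BOS", "DXB"), Factors(True))

# fields() accepts the class or an instance; the values come from the instance
for field in fields(r):
    print(field.name, getattr(r, field.name))

# asdict() recurses into nested dataclasses and returns plain dicts
print(asdict(r))
```

Note that asdict() copies values rather than returning references to the original objects.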
answered Nov 06 '22 09:11 by Andrey Topoleov


I'd also make use of __dataclass_fields__, a dict that maps each field name to its Field object (the type is available as Field.type). For example, something like:

for field in my_instance.__dataclass_fields__:  # my_instance: an instance of your dataclass
    value = getattr(my_instance, field)
    print(field, value)
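One difference from fields() is worth knowing. __dataclass_fields__ also contains ClassVar and InitVar pseudo-fields, while fields() returns only real instance fields. Because an InitVar is never stored on the instance, the getattr loop above would raise AttributeError for it. The Example class below is a hypothetical illustration, not from the question:

```python
from dataclasses import InitVar, dataclass, fields
from typing import ClassVar


@dataclass
class Example:
    VERSION: ClassVar[int] = 1  # class-level constant, not an instance field
    name: str
    scale: InitVar[int]  # passed to __post_init__, never stored on the instance

    def __post_init__(self, scale):
        self.name = self.name * scale


e = Example("ab", scale=2)

print(list(Example.__dataclass_fields__))  # includes the pseudo-fields
print([f.name for f in fields(e)])         # only real instance fields
print(hasattr(e, "scale"))                 # the InitVar value is not kept
```

For the classes in the question, which have neither ClassVar nor InitVar, both approaches give the same names.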
answered Nov 06 '22 08:11 by Aziz Alto


I'd leave the built-in __str__ methods alone and instead add a method called visualize (or similar) to the Route class, but that's a matter of taste. You also shouldn't override the __init__ of a dataclass unless you absolutely have to; just unpack (splat) your input dict into the generated constructor.
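Splatting a dict into the generated constructor looks like this. It is a minimal sketch using the Factor class from this answer, duplicated so it runs on its own. Unlike route.get(...) in the question's __init__, which silently yields None for a missing or misspelled key, the generated __init__ rejects unknown keys with a TypeError:

```python
from dataclasses import dataclass


@dataclass
class Factor:
    do: bool
    hub: int = 0


params = {"do": True}
f = Factor(**params)  # the generated __init__ fills in hub's default
print(f)

try:
    Factor(**{"do": True, "hubs": 3})  # misspelled key
except TypeError as exc:
    print("rejected:", exc)
```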

Final nit: prefer getattr/setattr over accessing __dict__ directly. Dataclasses are often combined with __slots__ to store their attributes, and slotted instances have no __dict__ at all, so your code would break.
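A small sketch of that failure mode follows. SlottedFactor is a hypothetical class, and its slots are declared by hand so the sketch runs on any Python 3 version. Declaring them by hand means the fields can't have class-level defaults. On Python 3.10+, @dataclass(slots=True) does this for you and does allow defaults:

```python
from dataclasses import dataclass, fields


@dataclass
class SlottedFactor:
    # manually declared slots; fields therefore can't have class-level defaults
    __slots__ = ("do", "hub")
    do: bool
    hub: int


f = SlottedFactor(True, 0)
print(hasattr(f, "__dict__"))  # slotted instances have no __dict__

# getattr uses normal attribute lookup, so it still works with slots
print({fld.name: getattr(f, fld.name) for fld in fields(f)})
```

So self.skyteam.__dict__.get(var) would raise AttributeError here, while getattr(self.skyteam, var) keeps working.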

So I'd go with something like this, using the tabulate library to handle the rendering:

from dataclasses import dataclass, fields
import tabulate

@dataclass
class Factor:
    do: bool
    hub: int = 0 # this is how to add a default value


@dataclass
class Route:
    route: tuple
    skyteam: Factor
    star: Factor
    oneworld: Factor

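    def __post_init__(self):
        # wrap each raw Factor argument (here a bool such as True) in a
        # Factor object; values that are already Factor instances are kept
        for field in fields(self):
            if issubclass(field.type, Factor):
                value = getattr(self, field.name)
                if not isinstance(value, field.type):
                    setattr(self, field.name, field.type(value))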

    def visualize(self):
        factors = {
            field.name: getattr(self, field.name)
            for field in fields(self)
            if issubclass(field.type, Factor)
        }
        rows = [[self.route, *factors]]  # header
        for field in fields(Factor):
            rows.append([field.name, *[getattr(f, field.name) for f in factors.values()]])
        print(tabulate.tabulate(rows))

This works fine for your example:

>>> r = Route(**{'route': ('BOS', 'DXB'), 'skyteam': True, 'star': True, 'oneworld': True})
>>> r.visualize()
--------------  -------  ----  --------
('BOS', 'DXB')  skyteam  star  oneworld
do              True     True  True
hub             0        0     0
--------------  -------  ----  --------

This solution should keep working if you add more fields to the Factor class or more Factor-typed fields to Route.
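That claim can be checked without tabulate by returning the rows instead of printing them. This is a sketch: the lounge field, the extra group and the rows method name are all hypothetical additions for illustration. The table-building logic is the same as in visualize above:

```python
from dataclasses import dataclass, fields


@dataclass
class Factor:
    do: bool
    hub: int = 0
    lounge: bool = False  # hypothetical extra field


@dataclass
class Route:
    route: tuple
    skyteam: Factor
    star: Factor
    oneworld: Factor
    extra: Factor  # hypothetical fourth Factor-typed field

    def __post_init__(self):
        # wrap raw arguments (bools here) in Factor objects
        for field in fields(self):
            if issubclass(field.type, Factor):
                value = getattr(self, field.name)
                if not isinstance(value, field.type):
                    setattr(self, field.name, field.type(value))

    def rows(self):
        """Same logic as visualize(), but returns the rows instead of printing."""
        factors = {
            field.name: getattr(self, field.name)
            for field in fields(self)
            if issubclass(field.type, Factor)
        }
        table = [[self.route, *factors]]  # header
        for field in fields(Factor):
            table.append([field.name, *[getattr(f, field.name) for f in factors.values()]])
        return table


r = Route(("BOS", "DXB"), True, True, False, True)
for row in r.rows():
    print(row)
```

No loop had to change: the new field shows up as a new row, and the new group shows up as a new column.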

answered Nov 06 '22 10:11 by Arne