Flatten nested dictionaries, compressing keys

Basically the same way you would flatten a nested list; you just have to do the extra work of iterating the dict by key/value, creating new keys for your new dictionary, and creating the dictionary at the final step.

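# The abstract base classes live in collections.abc (required on Python 3.10+).
from collections.abc import MutableMapping

def flatten(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        # Prefix the key with its parent's key, unless we are at the top level.
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            # Recurse into nested mappings and collect their flattened pairs.
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)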

>>> flatten({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
{'a': 1, 'c_a': 2, 'c_b_x': 5, 'c_b_y': 10, 'd': [1, 2, 3]}

Or, if you are already using pandas, you can do it with json_normalize() like so:

import pandas as pd

d = {'a': 1,
     'c': {'a': 2, 'b': {'x': 5, 'y' : 10}},
     'd': [1, 2, 3]}

df = pd.json_normalize(d, sep='_')

print(df.to_dict(orient='records')[0])

Output:

{'a': 1, 'c_a': 2, 'c_b_x': 5, 'c_b_y': 10, 'd': [1, 2, 3]}

There are two big points the original poster needs to consider:

  1. Are there keyspace clobbering issues? For example, {'a_b':{'c':1}, 'a':{'b_c':2}} would result in {'a_b_c':???}. The solution below sidesteps the problem by returning an iterable of pairs.
  2. If performance is an issue, does the key-reducer function (which I hereby refer to as 'join') require access to the entire key-path, or can it just do O(1) work at every node in the tree? If you want to be able to say joinedKey = '_'.join(keys), that will cost you O(N^2) running time. However, if you're willing to say nextKey = previousKey+'_'+thisKey, that gets you O(N) time. The solution below lets you do both (since you could merely concatenate all the keys, then postprocess them).
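
To make the first point concrete, here is a minimal sketch of the clobbering problem. The helper flatten_pairs is a hypothetical name used only for this illustration (it is not the flattenDict function from this answer); it assumes string keys and joins them with '_'.

```python
def flatten_pairs(d, prefix=''):
    """Hypothetical helper: yield (joined_key, value) pairs, joining with '_'."""
    for k, v in d.items():
        key = prefix + '_' + k if prefix else k
        if isinstance(v, dict):
            yield from flatten_pairs(v, key)
        else:
            yield key, v

pairs = list(flatten_pairs({'a_b': {'c': 1}, 'a': {'b_c': 2}}))
print(pairs)        # [('a_b_c', 1), ('a_b_c', 2)]
print(dict(pairs))  # {'a_b_c': 2}  (the first value is silently lost)
```

Returning pairs keeps both entries visible, so the caller can decide how to handle the collision before building a dict.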

(Performance is not likely an issue, but I'll elaborate on the second point in case anyone else cares: In implementing this, there are numerous dangerous choices. If you do this recursively and yield and re-yield, or anything equivalent which touches nodes more than once (which is quite easy to accidentally do), you are doing potentially O(N^2) work rather than O(N). This is because maybe you are calculating a key a then a_1 then a_1_i..., and then calculating a then a_1 then a_1_ii..., but really you shouldn't have to calculate a_1 again. Even if you aren't recalculating it, re-yielding it (a 'level-by-level' approach) is just as bad. A good example is to think about the performance on {1:{1:{1:{1:...(N times)...{1:SOME_LARGE_DICTIONARY_OF_SIZE_N}...}}}})
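
The re-yielding cost can be made visible with a small experiment. This is only a sketch: flatten_reyield, make_chain and the counter are hypothetical names introduced for the illustration, and the counter simply records how many times a pair is handed up one level.

```python
from functools import reduce

counter = {'yields': 0}  # hypothetical instrumentation, for illustration only

def flatten_reyield(d, path=()):
    """Naive lazy flatten: each pair is re-yielded once per enclosing level."""
    for k, v in d.items():
        if isinstance(v, dict):
            for item in flatten_reyield(v, path + (k,)):
                counter['yields'] += 1
                yield item
        else:
            counter['yields'] += 1
            yield path + (k,), v

def make_chain(depth, width):
    """A dict of `width` leaves wrapped in `depth` single-key dicts."""
    leaf = {i: 0 for i in range(width)}
    return reduce(lambda acc, x: {x: acc}, range(depth), leaf)

pairs = list(flatten_reyield(make_chain(50, 50)))
print(len(pairs), counter['yields'])  # 50 2550
```

With 50 leaves under 50 wrapper levels, every leaf is yielded once and then re-yielded 50 more times, so the work grows with depth times width rather than with the number of nodes.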

Below is a function I wrote, flattenDict(d, join=..., lift=...), which can be adapted to many purposes and can do what you want. Sadly, it is fairly hard to make a lazy version of this function without incurring the above performance penalties (many Python tools such as itertools.chain.from_iterable aren't actually efficient here, which I only realized after extensive testing of three different versions of this code before settling on this one).

from collections.abc import Mapping
from operator import add

_FLAG_FIRST = object()  # sentinel: no key has been accumulated yet

def flattenDict(d, join=add, lift=lambda x: (x,)):
    results = []
    def visit(subdict, results, partialKey):
        for k, v in subdict.items():
            # lift() preprocesses each key; join() folds it into the path so far.
            newKey = lift(k) if partialKey is _FLAG_FIRST else join(partialKey, lift(k))
            if isinstance(v, Mapping):
                visit(v, results, newKey)
            else:
                results.append((newKey, v))
    visit(d, results, _FLAG_FIRST)
    return results

To better understand what's going on, below is a diagram for those unfamiliar with reduce(left), otherwise known as "fold left". Sometimes it is drawn with an initial value in place of k0 (not part of the list, passed into the function). Here, J is our join function. We preprocess each kn with lift(k).

               [k0,k1,...,kN].foldleft(J)
                           /    \
                         ...    kN
                         /
       J(J(J(k0,k1),k2),k3)
                       /  \
                      /    \
           J(J(k0,k1),k2)   k3
                    /   \
                   /     \
             J(k0,k1)    k2
                 /  \
                /    \
               k0     k1

This is in fact the same as functools.reduce, but where our function does this to all key-paths of the tree.

>>> from functools import reduce
>>> reduce(lambda a,b:(a,b), range(5))
((((0, 1), 2), 3), 4)
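
As a small illustration of the same fold applied to a single key-path (the path and the join function here are just example values, not part of flattenDict), reducing with a string join produces the compressed key:

```python
from functools import reduce

path = ['c', 'cc', 'aaa']          # one key-path from root to leaf
join = lambda a, b: a + '_' + b    # example join function J

print(reduce(join, path))          # c_cc_aaa
```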

Demonstration (which I'd otherwise put in docstring):

>>> testData = {
...     'a':1,
...     'b':2,
...     'c':{
...         'aa':11,
...         'bb':22,
...         'cc':{
...             'aaa':111
...         }
...     }
... }
>>> from pprint import pprint as pp

>>> pp(dict( flattenDict(testData) ))
{('a',): 1,
 ('b',): 2,
 ('c', 'aa'): 11,
 ('c', 'bb'): 22,
 ('c', 'cc', 'aaa'): 111}

>>> pp(dict( flattenDict(testData, join=lambda a,b:a+'_'+b, lift=lambda x:x) ))
{'a': 1, 'b': 2, 'c_aa': 11, 'c_bb': 22, 'c_cc_aaa': 111}    

>>> pp(dict( (v,k) for k,v in flattenDict(testData, lift=hash, join=lambda a,b:hash((a,b))) ))
{1: 12416037344,
 2: 12544037731,
 11: 5470935132935744593,
 22: 4885734186131977315,
 111: 3461911260025554326}

Performance:

from functools import reduce
def makeEvilDict(n):
    # range() is not a list in Python 3, so convert it before concatenating.
    return reduce(lambda acc,x:{x:acc}, [{i:0 for i in range(n)}]+list(range(n)))

import timeit
def time(runnable):
    t0 = timeit.default_timer()
    _ = runnable()
    t1 = timeit.default_timer()
    print('took {:.2f} seconds'.format(t1-t0))

>>> pp(makeEvilDict(8))
{7: {6: {5: {4: {3: {2: {1: {0: {0: 0,
                                 1: 0,
                                 2: 0,
                                 3: 0,
                                 4: 0,
                                 5: 0,
                                 6: 0,
                                 7: 0}}}}}}}}}

import sys
sys.setrecursionlimit(1000000)

forget = lambda a,b:''

>>> time(lambda: dict(flattenDict(makeEvilDict(10000), join=forget)) )
took 0.10 seconds
>>> time(lambda: dict(flattenDict(makeEvilDict(100000), join=forget)) )
[1]    12569 segmentation fault  python

... sigh, don't think that one is my fault...
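
The crash is most likely the interpreter running out of C stack once the recursion limit has been raised that far (an assumption; the run above was not investigated further). One way to avoid deep recursion altogether is an explicit stack. The following is a minimal sketch, not the flattenDict above: flatten_iterative is a hypothetical name, and it assumes string keys joined by sep.

```python
from collections.abc import Mapping

def flatten_iterative(d, sep='_'):
    """Hypothetical stack-based flatten; returns (key, value) pairs in order."""
    results = []
    # Each stack frame holds the key prefix and a partially consumed iterator.
    stack = [(None, iter(d.items()))]
    while stack:
        prefix, items = stack[-1]
        for k, v in items:
            key = k if prefix is None else prefix + sep + k
            if isinstance(v, Mapping):
                # Descend: push the child and resume the parent later.
                stack.append((key, iter(v.items())))
                break
            results.append((key, v))
        else:
            # Iterator exhausted without descending: this level is done.
            stack.pop()
    return results

nested = {'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y': 10}}, 'd': [1, 2, 3]}
print(flatten_iterative(nested))
# [('a', 1), ('c_a', 2), ('c_b_x', 5), ('c_b_y', 10), ('d', [1, 2, 3])]
```

Because the traversal state lives in a Python list, nesting depth is limited by memory rather than by the recursion limit. Note that concatenating string keys at every level still costs time proportional to the key length.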


[unimportant historical note due to moderation issues]

Regarding the alleged duplicate of Flatten a dictionary of dictionaries (2 levels deep) of lists

That question's solution can be implemented in terms of this one by doing sorted( sum(flatten(...),[]) ). The reverse is not possible: while it is true that the values of flatten(...) can be recovered from the alleged duplicate by mapping a higher-order accumulator, one cannot recover the keys. (edit: Also it turns out that the alleged duplicate owner's question is completely different, in that it only deals with dictionaries exactly 2-level deep, though one of the answers on that page gives a general solution.)


If you're using pandas, there is a function hidden in pandas.io.json._normalize [1] called nested_to_record which does exactly this.

from pandas.io.json._normalize import nested_to_record

flat = nested_to_record(my_dict, sep='_')

[1] In pandas versions 0.24.x and older, use pandas.io.json.normalize (without the _).


Here is a kind of "functional", "one-liner" implementation. It is recursive, and based on a conditional expression and a dict comprehension.

def flatten_dict(dd, separator='_', prefix=''):
    return { prefix + separator + k if prefix else k : v
             for kk, vv in dd.items()
             for k, v in flatten_dict(vv, separator, kk).items()
             } if isinstance(dd, dict) else { prefix : dd }

Test:

In [2]: flatten_dict({'abc':123, 'hgf':{'gh':432, 'yu':433}, 'gfd':902, 'xzxzxz':{"432":{'0b0b0b':231}, "43234":1321}}, '.')
Out[2]: 
{'abc': 123,
 'gfd': 902,
 'hgf.gh': 432,
 'hgf.yu': 433,
 'xzxzxz.432.0b0b0b': 231,
 'xzxzxz.43234': 1321}
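
For readers who find the comprehension dense, the same logic can be unrolled into explicit loops. This is a sketch under the assumption that it mirrors the one-liner exactly; flatten_dict_expanded is a hypothetical name used only here.

```python
def flatten_dict_expanded(dd, separator='_', prefix=''):
    """Hypothetical unrolled equivalent of the one-liner flatten_dict."""
    # Base case: a non-dict value becomes a single entry under its own key.
    if not isinstance(dd, dict):
        return {prefix: dd}
    result = {}
    for kk, vv in dd.items():
        # Flatten each child using its own key as the prefix...
        for k, v in flatten_dict_expanded(vv, separator, kk).items():
            # ...then prepend this level's prefix, if there is one.
            result[prefix + separator + k if prefix else k] = v
    return result

data = {'abc': 123, 'hgf': {'gh': 432, 'yu': 433}, 'gfd': 902,
        'xzxzxz': {'432': {'0b0b0b': 231}, '43234': 1321}}
flat = flatten_dict_expanded(data, '.')
print(flat['xzxzxz.432.0b0b0b'])  # 231
```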