Flatten a nested list with indices in Python

I have a list ['','','',['',[['a','b'],['c']]],[[['a','b'],['c']]],[[['d']]]]

I want to flatten the list with indices and the output should be as follows:

flat_list = ['', '', '', '', 'a', 'b', 'c', 'a', 'b', 'c', 'd']
indices = [0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]

How can I do this?

I have tried this:

def flat(nums):
    res = []
    index = []
    for i in range(len(nums)):
        if isinstance(nums[i], list):
            res.extend(nums[i])
            index.extend([i]*len(nums[i]))
        else:
            res.append(nums[i])
            index.append(i)
    return res,index

But this doesn't work as expected.
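Running this attempt on the example input (with the missing comma restored) shows why: extend only unpacks one level of nesting, so deeper sublists end up in the result unchanged, and each index is repeated once per direct child of a sublist rather than once per leaf item. A minimal reproduction:

```python
def flat(nums):
    # The attempt from the question: each sublist is spliced into the
    # result with extend(), which removes only a single level of nesting.
    res = []
    index = []
    for i in range(len(nums)):
        if isinstance(nums[i], list):
            res.extend(nums[i])
            index.extend([i] * len(nums[i]))
        else:
            res.append(nums[i])
            index.append(i)
    return res, index


nested = ['', '', '', ['', [['a', 'b'], ['c']]], [[['a', 'b'], ['c']]], [[['d']]]]
res, index = flat(nested)
print(res)    # sublists such as [['a', 'b'], ['c']] are still present
print(index)  # [0, 1, 2, 3, 3, 4, 5]
```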

asked Mar 11 '19 by Miffy


Answer

TL;DR

This implementation handles nested iterables with unbounded depth:

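def enumerate_items_from(iterable):
    # Stack of iterators: one per (sub-)iterable currently being traversed.
    cursor_stack = [iter(iterable)]
    item_index = -1  # index of the current item in the "main" list
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:
            # This sub-iterable is exhausted: resume its parent.
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:
            # The item comes directly from the main list.
            item_index += 1
        # Strings are iterable but must be treated as atomic items.
        if not isinstance(item, str):
            try:
                cursor_stack.append(iter(item))
                continue
            except TypeError:
                # Not iterable: this is a leaf item.
                pass
        yield item, item_index

def flat(iterable):
    # Split the (item, index) pairs into two lists.
    return map(list, zip(*enumerate_items_from(iterable)))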

This can be used to produce the desired output:


>>> nested = ['', '', '', ['', [['a', 'b'], ['c']]], [[['a', 'b'], ['c']]], [[['d']]]]
>>> flat_list, item_indexes = flat(nested)
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]
>>> print(flat_list)
['', '', '', '', 'a', 'b', 'c', 'a', 'b', 'c', 'd']

Note that you should probably yield the index first, to mimic what enumerate does; that would make the function easier to use for people who already know enumerate.
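A minimal sketch of that suggestion: enumerate_flat is a hypothetical name for a variant of enumerate_items_from that performs the same iterative traversal but yields (index, item) instead of (item, index).

```python
def enumerate_flat(iterable):
    """Yield (top_level_index, item) pairs, index first like enumerate().

    Hypothetical variant of enumerate_items_from with the pair order swapped.
    """
    cursor_stack = [iter(iterable)]
    item_index = -1
    while cursor_stack:
        try:
            item = next(cursor_stack[-1])
        except StopIteration:
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:
            item_index += 1
        if not isinstance(item, str):
            try:
                cursor_stack.append(iter(item))
                continue
            except TypeError:
                pass
        yield item_index, item  # index first, as enumerate() does


for index, item in enumerate_flat(['x', ['y', ['z']], 'w']):
    print(index, item)
```

With this order, a loop over the result unpacks exactly like a loop over enumerate(...).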

Important remark: unless you are certain your lists will not be nested too deeply, you shouldn't use a recursion-based solution. Otherwise, as soon as you have a nested list with a depth greater than 1000 (CPython's default recursion limit, as returned by sys.getrecursionlimit()), your code will crash. I explain this here. Note that even a simple call to str(list) will crash on a test case with depth > 1000 (for some Python implementations or versions the limit is higher, but it is always bounded). The typical exception you'll get with recursion-based solutions is the following (in short, this is due to how the Python call stack works):

RecursionError: maximum recursion depth exceeded ... 
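To illustrate the remark, here is a straightforward recursive version (flatten_recursive is a hypothetical function written only for this comparison) together with an input nested twice as deep as the current recursion limit. Each nesting level costs one Python call frame, so it raises RecursionError:

```python
import sys


def flatten_recursive(nested, index=None, items=None, indexes=None):
    # Recursive flattening: every nested list adds one Python call frame.
    if items is None:
        items, indexes = [], []
    for position, item in enumerate(nested):
        # In the top-level call (index is None) the item's own position is used.
        current = position if index is None else index
        if isinstance(item, list):
            flatten_recursive(item, current, items, indexes)
        else:
            items.append(item)
            indexes.append(current)
    return items, indexes


print(flatten_recursive([1, [2, [3]], 4]))  # ([1, 2, 3, 4], [0, 1, 1, 2])

# Build [[[...[0]...]]] nested well beyond the recursion limit.
deep = [0]
for _ in range(sys.getrecursionlimit() * 2):
    deep = [deep]

try:
    flatten_recursive(deep)
except RecursionError as error:
    print("RecursionError:", error)
```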

Implementation details

I'll go step by step: first we will flatten a list, then we will output both the flattened list and the depth of each item, and finally we will output both the list and the corresponding item indexes in the "main list".

Flattening a list

The iterative approach is actually well suited to this problem; we can start from a simple (non-recursive) list-flattening algorithm:

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:       # post-order
            cursor_stack.pop()
            continue
        if isinstance(item, list):  # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item              # in-order

Computing depth

We can get the depth by looking at the stack size: depth = len(cursor_stack) - 1. Only the else branch of items_from needs to change:

        else:
            yield item, len(cursor_stack) - 1      # in-order
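Putting that change into the full function gives the following self-contained version (repeated here so it runs on its own). Since the stack holds one iterator per enclosing list, an item found directly in the main list has depth 0, and empty sublists simply produce nothing:

```python
def flatten(iterable):
    return list(items_from(iterable))


def items_from(iterable):
    # Yields (item, depth) pairs; depth 0 means "directly in the main list".
    cursor_stack = [iter(iterable)]
    while cursor_stack:
        try:
            item = next(cursor_stack[-1])
        except StopIteration:
            # The innermost list is exhausted: resume its parent.
            cursor_stack.pop()
            continue
        if isinstance(item, list):
            # Descend into the sublist before continuing with its siblings.
            cursor_stack.append(iter(item))
        else:
            # One iterator per enclosing list, so depth = stack size - 1.
            yield item, len(cursor_stack) - 1


print(flatten([1, [2, [3]]]))  # [(1, 0), (2, 1), (3, 2)]
print(flatten([[], [[]], 6]))  # [(6, 0)]
```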

This returns a list of (item, depth) pairs; if we need to split this result into two sequences, we can use the zip function:

>>> a = [1,  2,  3, [4 , [[5, 6], [7]]], [[[8, 9], [10]]], [[[11]]]]
>>> flatten(a)
[(1, 0), (2, 0), (3, 0), (4, 1), (5, 3), (6, 3), (7, 3), (8, 3), (9, 3), (10, 3), (11, 3)]
>>> flat_list, depths = zip(*flatten(a))
>>> print(flat_list)
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
>>> print(depths)
(0, 0, 0, 1, 3, 3, 3, 3, 3, 3, 3)

We will now do something similar to get item indexes instead of depths.

Computing item indexes

To compute item indexes (in the main list) instead, we need to count the number of top-level items seen so far. This can be done by adding 1 to item_index every time we iterate over an item at depth 0 (that is, when the stack size is equal to 1):

def flatten(iterable):
    return list(items_from(iterable))

def items_from(iterable):
    cursor_stack = [iter(iterable)]
    item_index = -1
    while cursor_stack:
        sub_iterable = cursor_stack[-1]
        try:
            item = next(sub_iterable)
        except StopIteration:             # post-order
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:        # If current item is in "main" list
            item_index += 1               
        if isinstance(item, list):        # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item, item_index        # in-order

Similarly, we split the pairs into two iterators using zip, and we use map to turn both of them into lists:

>>> a = [1,  2,  3, [4 , [[5, 6], [7]]], [[[8, 9], [10]]], [[[11]]]]
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]
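Because every item carries the index of the top-level element it came from, the two lists can be regrouped. A small sketch using itertools.groupby (the indexes come out in non-decreasing order, so grouping consecutive pairs is enough); note that a top-level element containing no leaf items at all, such as [], would produce no entry:

```python
from itertools import groupby

flat_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
item_indexes = [0, 1, 2, 3, 3, 3, 3, 4, 4, 4, 5]

# Map each top-level index to the leaf items that came from that element.
leaves_by_index = {
    index: [item for item, _ in group]
    for index, group in groupby(zip(flat_list, item_indexes), key=lambda pair: pair[1])
}
print(leaves_by_index)
```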

Improvement: handling iterable inputs

Being able to accept a broader range of nested iterables as input could be desirable (especially if you build this for others to use). For example, the current implementation doesn't work as expected when the input contains nested iterators:

>>> a = iter([1, '2',  3, iter([4, [[5, 6], [7]]])])
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, '2', 3, <list_iterator object at 0x100f6a390>]
>>> print(item_indexes)
[0, 1, 2, 3]

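If we want this to work, we need to be a bit careful: strings are iterable, but we want them to be treated as atomic items, not as lists of characters. (Iterating over a one-character string yields that same string again, so without special handling the traversal would never terminate.) Instead of assuming the input is a list, as we did before: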

        if isinstance(item, list):        # pre-order
            cursor_stack.append(iter(item))
        else:
            yield item, item_index        # in-order

we no longer check for the list type (apart from the special case of str); instead, we try to use the item as an iterable, and if that fails we know it is not one (duck typing):

        if not isinstance(item, str):
            try:
                cursor_stack.append(iter(item))
                continue
            except TypeError:
                # item is not an iterable object:
                pass
        yield item, item_index

With this implementation, we have:

>>> a = iter([1, 2,  3, iter([4, [[5, 6], [7]]])])
>>> flat_list, item_indexes = map(list, zip(*flatten(a)))
>>> print(flat_list)
[1, 2, 3, 4, 5, 6, 7]
>>> print(item_indexes)
[0, 1, 2, 3, 3, 3, 3]

Building test cases

If you need to generate test cases with large depths, you can use this piece of code:

def build_deep_list(depth):
    """Returns a list of the form $l_{depth} = [depth-1, l_{depth-1}]$
    for $depth > 1$, with $l_1 = [0]$.
    """
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list

You can use this to make sure my implementation doesn't crash when the depth is large:

a = build_deep_list(1200)
flat_list, item_indexes = map(list, zip(*flatten(a)))
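For build_deep_list(n) the expected result is easy to derive: the values n-1, n-2, ..., 0 in that order, where only n-1 sits directly in the main list and everything else lives inside its second element (index 1). A self-contained check along those lines, using the list-only iterative traversal with item indexes from above:

```python
def build_deep_list(depth):
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list


def items_from(iterable):
    # Iterative traversal yielding (item, top-level index) pairs.
    cursor_stack = [iter(iterable)]
    item_index = -1
    while cursor_stack:
        try:
            item = next(cursor_stack[-1])
        except StopIteration:
            cursor_stack.pop()
            continue
        if len(cursor_stack) == 1:
            item_index += 1
        if isinstance(item, list):
            cursor_stack.append(iter(item))
        else:
            yield item, item_index


depth = 1200
flat_list, item_indexes = map(list, zip(*items_from(build_deep_list(depth))))
assert flat_list == list(range(depth - 1, -1, -1))
assert item_indexes == [0] + [1] * (depth - 1)
print("ok")
```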

We can also check that such a list can't be printed with the str function:

>>> a = build_deep_list(1200)
>>> str(a)
RecursionError: maximum recursion depth exceeded while getting the repr of an object

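str(list) calls repr on every element of the list, so each nesting level adds another nested repr call; that is what hits the limit here (the exact depth at which this happens depends on the Python version).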
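If you do need a printable summary of such a list, the standard library's reprlib.repr renders only the first few nesting levels (reprlib.Repr.maxlevel, 6 by default) and abbreviates deeper ones with "...", so it never recurses through the whole structure. A short sketch:

```python
import reprlib


def build_deep_list(depth):
    # Same construction as above: [depth - 1, [depth - 2, [... [0]]]]
    sub_list = [0]
    for d in range(1, depth):
        sub_list = [d, sub_list]
    return sub_list


a = build_deep_list(1200)
# Only the outermost levels are rendered; the rest is abbreviated.
summary = reprlib.repr(a)
print(summary)
```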

Concluding remarks

In the end, I agree that recursive implementations are much easier to read (as the call stack does half the hard work for us), but when implementing a low-level function like this one, I think it is a good investment to have code that works in all cases (or at least in all the cases you can think of), especially when the solution is not that hard. It's also a way not to forget how to write non-recursive code that works on tree-like structures (which may not happen often unless you implement data structures yourself, but it's a good exercise).

Note that everything I say “against” recursion only holds because Python doesn't optimize call stack usage for recursion (see Tail Recursion Elimination in Python), whereas many compiled languages perform tail call optimization (TCO). This means that even if you write a perfectly tail-recursive function in Python, it will crash on deeply nested lists.
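A tiny illustration of that last point: the function below (count_down is a made-up name for this example) is tail-recursive, since the recursive call is the very last thing it does, yet CPython still creates one frame per call, so it fails once the depth exceeds the recursion limit:

```python
import sys


def count_down(n):
    # Tail call: nothing is left to do after the recursive call returns,
    # but CPython does not eliminate it, so frames still pile up.
    if n == 0:
        return "done"
    return count_down(n - 1)


print(count_down(100))  # done

try:
    count_down(sys.getrecursionlimit() + 100)
except RecursionError:
    print("RecursionError, even though the recursion is a tail call")
```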

If you need more details on the list-flattening algorithm, you can refer to the post I linked.

answered Oct 12 '22 by cglacet