
Explanation of the Python itertools.product() implementation?

How does this Python magic work?

The code in question comes from the Python itertools.product documentation:

def product(*args):
    pools = map(tuple, args)
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

Note: Yes, I understand this is not the actual implementation. I have also removed the repeat argument to simplify the code and focus the discussion.

I understand what the above code does (its output), but I'm looking for an explanation of how it works. I'm already familiar with list comprehensions, including nested fors.

This is specifically puzzling:

for pool in pools:
    result = [x+[y] for x in result for y in pool]

I've attempted to convert the above code into a series of for loops without any list comprehension, but I have not managed to get correct results. I would normally expect this kind of algorithm to need recursion to handle an arbitrary number of input sets, hence my confusion that the above code seems to do this iteratively.

Can anyone explain how this code works?

asked Oct 18 '25 16:10 by djsmith



1 Answer

Here is the list-comprehension converted into a regular for-loop, if that helps you understand:

def product_nocomp(*args):
    pools = map(tuple, args)
    result = [[]]
    for pool in pools:
        _temp = []
        for x in result:
            for y in pool:
                _temp.append(x + [y])
        result = _temp
    for prod in result:
        yield tuple(prod)
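
As a quick sanity check, the loop version can be compared against the real itertools.product. This is only a sketch assuming Python 3: the function is repeated so the snippet runs on its own, and the sample inputs are arbitrary choices, not anything from the documentation.

```python
from itertools import product


def product_nocomp(*args):
    # Materialize each input once so it can be iterated repeatedly.
    pools = [tuple(pool) for pool in args]
    result = [[]]
    for pool in pools:
        _temp = []
        for x in result:
            for y in pool:
                # x + [y] builds a new list; x itself is left unchanged.
                _temp.append(x + [y])
        result = _temp
    for prod in result:
        yield tuple(prod)


# Arbitrary sample inputs (an assumption for illustration).
inputs = ("ab", range(3), [True, False])
assert list(product_nocomp(*inputs)) == list(product(*inputs))

# With no inputs, both yield exactly one empty tuple.
assert list(product_nocomp()) == list(product()) == [()]
```

The empty-input case shows why result starts as [[]] rather than []: the product of zero pools is a single empty combination, and every later pool extends it.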

And here are some illuminating print statements:

In [9]: def product_nocomp(*args):
   ...:     pools = list(map(tuple, args))
   ...:     result = [[]]
   ...:     print("Pools: ", pools)
   ...:     for pool in pools:
   ...:         print(result, end=' | ')
   ...:         _temp = []
   ...:         for x in result:
   ...:             for y in pool:
   ...:                 _temp.append(x + [y])
   ...:         result = _temp
   ...:         print(result)
   ...:     for prod in result:
   ...:         yield tuple(prod)
   ...:

In [10]: list(product_nocomp(range(2), range(2)))
Pools:  [(0, 1), (0, 1)]
[[]] | [[0], [1]]
[[0], [1]] | [[0, 0], [0, 1], [1, 0], [1, 1]]
Out[10]: [(0, 0), (0, 1), (1, 0), (1, 1)]

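So, for every pool in pools, it goes through every sublist in the intermediate result and, for every item in the current pool, builds a new list consisting of that sublist followed by the item. After the first pool, result holds all combinations of length 1; after the second, all combinations of length 2; and so on. Note that x + [y] creates new lists each time, so the sublists from the previous step are never mutated.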
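
On the recursion point raised in the question, one way to look at it is that the loop builds bottom-up what a recursive function would compute top-down: the product of k pools is the product of the first k - 1 pools, extended by each item of the last pool. The following is a minimal sketch of that view, assuming Python 3; product_recursive is a hypothetical name used only for illustration, not something from the itertools documentation.

```python
from itertools import product


def product_recursive(*pools):
    # Hypothetical helper for illustration only.
    # Base case: the product of zero pools is one empty combination,
    # which plays the same role as result = [[]] in the loop version.
    if not pools:
        return [[]]
    *rest, last = pools
    last = tuple(last)  # allow one-shot iterables to be reused per prefix
    # Recursive step: extend every combination of the earlier pools
    # with every item of the last pool.
    return [x + [y] for x in product_recursive(*rest) for y in last]


combos = product_recursive(range(2), "ab")
assert combos == [list(t) for t in product(range(2), "ab")]
```

In the iterative code, result takes the place of the recursive call's return value: each pass through the outer for loop performs one recursive step, so no call stack is needed.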

answered Oct 20 '25 06:10 by juanpa.arrivillaga
