 

Generating All Subsets of a Set Using Recursive Backtracking (Python)

I'm trying to understand backtracking, but I'm stuck on this problem. Here's the prompt:

Given a set of distinct integers, return all possible subsets.

Example input: [1,2,3]

Example output: [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]

Here's my code:

def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    # print(temp)
    res.append(temp)
    for i in range(start, len(nums)):
        temp.append(nums[i])
        backtrack(res, temp, nums, i + 1)
        temp.pop() # Backtrack

When I return res, I get a list of 2**len(nums) empty lists. That's the right size, but the numbers aren't there. However, printing temp just before res.append(temp) shows that temp holds the right values.

E.g.

res = [[], [], [], [], [], [], [], []]

Output of the print statements:

[] [1] [1, 2] [1, 2, 3] [1, 3] [2] [2, 3] [3]

Why are the changes not carrying over to the res list?

Edit 1:

This solution works. What's the difference?

def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    # print(temp)
    res.append(temp)
    for i in range(start, len(nums)):
        backtrack(res, temp + [nums[i]], nums, i + 1)
asked Dec 18 '22 05:12 by YSA


1 Answer

You're appending multiple references to the same list object to res. We can demonstrate this by doing

result = subsets([1, 2, 3])
print([id(u) for u in result])

That will print a list of 8 identical IDs.

So the various changes that you make to temp get "lost": res ends up holding 8 references to whatever the final value of temp is, and since every temp.append() is undone by a matching temp.pop(), that final value is the empty list.
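Here's a minimal sketch of the same aliasing effect outside the recursion (the variable names are just for illustration). res ends up holding two references to one list, so undoing a change to that list empties both slots, while appending a copy keeps a snapshot.

```python
# Minimal demonstration of list aliasing; names are illustrative only.
temp = []
res = []

res.append(temp)        # res[0] is the same object as temp
temp.append(1)
res.append(temp)        # res[1] is *also* that same object
temp.pop()              # undo the change, as the backtracking code does

print(res)                         # [[], []]
print(res[0] is res[1] is temp)    # True

# Appending a copy breaks the link to temp.
snapshots = []
temp.append(1)
snapshots.append(temp[:])          # independent shallow copy: [1]
temp.pop()
print(snapshots)                   # [[1]]
```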


The simple way to fix this is to append copies of temp to res.

def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    res.append(temp[:])
    for i in range(start, len(nums)):
        temp.append(nums[i])
        backtrack(res, temp, nums, i + 1)
        temp.pop() # Backtrack

print(subsets([1, 2, 3]))

output

[[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
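The Edit 1 version in the question works for a related reason: temp + [nums[i]] builds a brand-new list for each recursive call instead of mutating temp, so every entry in res is a distinct object and no pop() is needed. Below is the question's Edit 1 code under the hypothetical names subsets_concat and _backtrack, plus a check that all 8 results really are separate objects.

```python
def subsets_concat(nums):
    """Edit 1 version from the question (renamed): each call gets a new list."""
    res = []
    _backtrack(res, [], nums, 0)
    return res

def _backtrack(res, temp, nums, start):
    res.append(temp)  # safe: nothing mutates temp after this point
    for i in range(start, len(nums)):
        # temp + [nums[i]] creates a new list; temp itself is left unchanged
        _backtrack(res, temp + [nums[i]], nums, i + 1)

result = subsets_concat([1, 2, 3])
print(result)                       # [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
print(len({id(u) for u in result})) # 8 distinct list objects
```

Both this version and the temp[:] fix end up creating one new list per subset; they differ only in whether the copy is made when recursing or when appending to res.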

FWIW, I realise that the main point of this exercise is to practice recursion, but in Python it's better to avoid recursion unless you really need it (e.g., for processing recursive data structures like trees): Python function calls are relatively slow, and CPython doesn't optimise tail calls and has a default recursion limit of 1000. Here's a more compact iterative solution.

def subsets(seq):
    z = [[]]
    for x in seq:
        z += [y + [x] for y in z]
    return z

To see how this works, we can expand it a little and add a print call.

def subsets(seq):
    z = [[]]
    for x in seq:
        print('z =', z, 'x =', x)
        w = []
        for y in z:
            w += [y + [x]]
        z += w
    return z

result = subsets([1, 2, 3])
print(result)  

output

z = [[]] x = 1
z = [[], [1]] x = 2
z = [[], [1], [2], [1, 2]] x = 3
[[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]

We start with list z containing a single empty list.

On each pass of the outer loop, we create a new list w by looping over z; each item in w is a copy of the corresponding item in z with the current x appended to it. We then extend z with the contents of w, so z doubles in length on every pass and finishes with 2**len(seq) subsets.
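One subtlety in the compact version: the list comprehension on the right of z += ... is fully built before z is extended, so it only sees the subsets that existed at the start of that pass. Here's a sketch (the name subsets_by_snapshot is hypothetical) that makes the snapshot explicit and checks that the result has 2**n entries.

```python
def subsets_by_snapshot(seq):
    """Same algorithm as the compact version, with the snapshot made explicit."""
    z = [[]]
    for x in seq:
        snapshot = list(z)  # the subsets that exist before x is considered
        # Don't write `for y in z: z.append(y + [x])` here: z would keep
        # growing while being iterated over, so that loop would never finish.
        for y in snapshot:
            z.append(y + [x])
    return z

for n in range(5):
    assert len(subsets_by_snapshot(range(n))) == 2 ** n

print(subsets_by_snapshot([1, 2, 3]))
# [[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]
```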


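Just for fun, here's an iterative generator that produces subsets (in natural order) from bitstrings: bit j of the counter i decides whether seq[j] goes into the subset. This method is actually quite efficient, and because it yields one subset at a time instead of building the whole list, it's good if you want all the subsets of a large sequence without consuming a lot of RAM.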

def subsets(seq):
    w = len(seq)
    for i in range(1<<w):
        yield [u for u, v in zip(seq, reversed('{:0{}b}'.format(i, w))) if v=='1']

print(*subsets([1, 2, 3]))

output

[] [1] [2] [1, 2] [3] [1, 3] [2, 3] [1, 2, 3]
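The string formatting in that generator is just a way of reading the bits of i. As an alternative sketch (not the answer's code; the name subsets_bits is hypothetical), the same bits can be tested directly with bitwise operators, where i >> j & 1 is 1 exactly when bit j of i is set:

```python
def subsets_bits(seq):
    """Yield subsets of seq; bit j of counter i says whether seq[j] is included."""
    seq = list(seq)
    n = len(seq)
    for i in range(1 << n):  # 2**n counters, one per subset
        # >> binds tighter than &, so this is (i >> j) & 1
        yield [seq[j] for j in range(n) if i >> j & 1]

print(*subsets_bits([1, 2, 3]))
# [] [1] [2] [1, 2] [3] [1, 3] [2, 3] [1, 2, 3]
```

Like the original generator, it only holds the current subset in memory.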
answered Dec 20 '22 19:12 by PM 2Ring