I'm trying to understand backtracking, but I'm stuck on this problem. Here's the prompt:
Given a set of distinct integers, return all possible subsets.
Example input: [1,2,3]
Example output: [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]
Here's my code:
```python
def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    # print(temp)
    res.append(temp)
    for i in range(start, len(nums)):
        temp.append(nums[i])
        backtrack(res, temp, nums, i + 1)
        temp.pop() # Backtrack
```
When I return `res`, I get a list of 2^(len(nums)) empty lists. That's the right size, but the numbers aren't there. However, printing `temp` before I do `res.append(temp)` shows that `temp` is carrying the right output.
For example, I get `res = [[], [], [], [], [], [], [], []]`, while the print statements show:

```
[] [1] [1, 2] [1, 2, 3] [1, 3] [2] [2, 3] [3]
```
Why are the changes not carrying over to the `res` list?
Edit 1:
This solution works. What's the difference?
```python
def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    # print(temp)
    res.append(temp)
    for i in range(start, len(nums)):
        backtrack(res, temp + [nums[i]], nums, i + 1)
```
You're appending multiple references to the same list object to `res`. We can demonstrate this by doing

```python
result = subsets([1, 2, 3])
print([id(u) for u in result])
```

That will print a list of 8 identical IDs.
So the various changes that you make to `temp` get "lost", and the final contents of `res` will be 8 references to whatever the final value of `temp` is. Since every `temp.append` is undone by a matching `temp.pop()`, in this case that final value is the empty list.
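To see the aliasing in isolation, here's a minimal sketch (the names `shared`, `res`, `temp` and `snapshots` are just for illustration). Appending the same list twice stores two references to one object, so a later mutation shows up in every slot, whereas appending a slice copy stores the contents as they are at that moment.

```python
shared = []
res = []

res.append(shared)       # stores a reference to shared, not a snapshot
shared.append(1)
res.append(shared)       # the very same object again
print(res)               # [[1], [1]]
print(res[0] is res[1])  # True: both slots hold one list

shared.pop()             # mutating that one list changes every slot
print(res)               # [[], []]

snapshots = []
temp = []
for x in [1, 2]:
    temp.append(x)
    snapshots.append(temp[:])  # temp[:] is a new list holding the current items
print(snapshots)         # [[1], [1, 2]]
```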
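The simple way to fix this is to append copies of `temp` to `res`. Your Edit 1 version works for the same reason: `temp + [nums[i]]` builds a brand-new list for each recursive call, so no list that has already been appended to `res` is ever mutated afterwards.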
```python
def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    res.append(temp[:])  # append a copy of temp, not a reference to it
    for i in range(start, len(nums)):
        temp.append(nums[i])
        backtrack(res, temp, nums, i + 1)
        temp.pop() # Backtrack

print(subsets([1, 2, 3]))
```

Output:

```
[[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
```
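As a quick sanity check, here's a sketch that reruns the `id` test on the fixed version and compares it with the Edit 1 approach. `backtrack_concat` is a hypothetical name for the question's Edit 1 helper, renamed so both versions can live in one script.

```python
def subsets(nums):
    res = []
    backtrack(res, [], nums, 0)
    return res

def backtrack(res, temp, nums, start):
    res.append(temp[:])  # snapshot of the current subset
    for i in range(start, len(nums)):
        temp.append(nums[i])
        backtrack(res, temp, nums, i + 1)
        temp.pop()

def backtrack_concat(res, temp, nums, start):
    # Edit 1 style: temp + [x] creates a new list, so nothing in res is mutated later
    res.append(temp)
    for i in range(start, len(nums)):
        backtrack_concat(res, temp + [nums[i]], nums, i + 1)

result = subsets([1, 2, 3])
ids = [id(u) for u in result]
print(len(set(ids)))  # 8: every subset is now a distinct list object

res2 = []
backtrack_concat(res2, [], [1, 2, 3], 0)
print(res2 == result)  # True: both versions produce the same subsets in the same order
```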
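FWIW, I realise that the main point of this exercise is to practice recursion, but in Python it's better to avoid recursion unless you really need it (e.g., for processing recursive data structures like trees): Python function calls are relatively slow, and CPython caps the recursion depth (1000 frames by default; see `sys.getrecursionlimit()`). Here's a more compact iterative solution.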
```python
def subsets(seq):
    z = [[]]
    for x in seq:
        z += [y + [x] for y in z]
    return z
```
To see how this works, we can expand it a little and add a `print` call.
```python
def subsets(seq):
    z = [[]]
    for x in seq:
        print('z =', z, 'x =', x)
        w = []
        for y in z:
            w += [y + [x]]
        z += w
    return z

result = subsets([1, 2, 3])
print(result)
```
Output:

```
z = [[]] x = 1
z = [[], [1]] x = 2
z = [[], [1], [2], [1, 2]] x = 3
[[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]
```
We start with a list `z` containing a single empty list. On each loop iteration we create a new list `w` by looping over `z` and making each item in `w` a copy of the corresponding item in `z` with the current `x` appended to it. We then extend `z` with the contents of `w`.
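Each pass keeps every existing subset and adds a copy of each one with `x` appended, so the number of subsets doubles on every element. The sketch below (`subsets_with_sizes` is a hypothetical name) records the length of `z` after each step to show this. It also relies on the fact that the right-hand list in `z += [...]` is fully built from the old `z` before `z` is extended; a loop that appended directly into `z` while iterating over it would keep seeing its own new items and never finish.

```python
def subsets_with_sizes(seq):
    z = [[]]
    sizes = [len(z)]
    for x in seq:
        # The comprehension is evaluated completely from the current z
        # before += extends z, so it never iterates over the new items.
        z += [y + [x] for y in z]
        sizes.append(len(z))  # each new element doubles the count
    return z, sizes

result, sizes = subsets_with_sizes([1, 2, 3, 4])
print(sizes)                            # [1, 2, 4, 8, 16]
print(len({tuple(s) for s in result}))  # 16: every subset appears exactly once
```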
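Just for fun, here's an iterative generator that produces subsets (in natural order) from bitstrings: the binary digits of a counter `i` decide which items of `seq` are included. This method is memory-efficient: because it's a generator, it yields one subset at a time instead of building all 2**n of them in a list, so it's good if you want all the subsets of a large sequence without consuming a lot of RAM.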
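```python
def subsets(seq):
    w = len(seq)
    for i in range(1 << w):
        # '{:0{}b}'.format(i, w) is i as a w-digit binary string; reversing it
        # lines bit 0 up with seq[0], and a '1' means "include this item"
        yield [u for u, v in zip(seq, reversed('{:0{}b}'.format(i, w))) if v == '1']

print(*subsets([1, 2, 3]))
```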
Output:

```
[] [1] [2] [1, 2] [3] [1, 3] [2, 3] [1, 2, 3]
```
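The same bitmask idea can be sketched with integer bit tests instead of string formatting: `i >> j & 1` is 1 exactly when bit `j` of `i` is set, which gives the same order as the version above. `subsets_bits` is a hypothetical name, and `itertools.islice` is used only to show that the generator builds subsets lazily, so taking the first few from a 40-element input is cheap even though there are 2**40 subsets in total.

```python
from itertools import islice

def subsets_bits(seq):
    seq = list(seq)
    for i in range(1 << len(seq)):
        # Bit j of the counter i decides whether seq[j] is in the subset
        yield [x for j, x in enumerate(seq) if i >> j & 1]

print(list(subsets_bits([1, 2, 3])))
# [[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]

# Only the subsets we ask for are ever built
print(list(islice(subsets_bits(range(40)), 4)))
# [[], [0], [1], [0, 1]]
```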