 

How to efficiently get all combinations where the sum is 10 or below in Python

Imagine you're allocating a fixed amount of resources (e.g. n=10) over some number of territories (e.g. t=5). I am trying to find an efficient way to get all the combinations where the sum is n or below.

For example, 10,0,0,0,0 is valid, as is 0,0,5,5,0, while 3,3,3,3,3 (sum 15) is obviously not.

I got this far:

import itertools

t = 5
n = 10
# Enumerate all (n+1)^t tuples, then keep only those with sum <= n
for x in itertools.product(range(n + 1), repeat=t):
    if sum(x) <= n:
        print(x)

This brute-force approach is incredibly slow, though. Is there a better way?
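For scale, with the n=10 and t=5 above, the brute force builds all (n+1)^t tuples, while by the stars-and-bars identity only C(n+t, t) of them have a sum of n or below. This short count (using `math.comb`, Python 3.8+) shows how much of the work is thrown away:

```python
import math

n = 10  # resources
t = 5   # territories

generated = (n + 1) ** t     # tuples produced by itertools.product
valid = math.comb(n + t, t)  # t-tuples of non-negative ints with sum <= n
print(generated, valid)      # 161051 3003
```

So roughly 98% of the generated tuples are rejected by the `sum(x) <= n` filter.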

Timings (1000 iterations):

Default (itertools.product)           --- time: 40.90 s
falsetru recursion                    --- time:  3.63 s
Aaron Williams Algorithm (impl, Tony) --- time:  0.37 s
PascalVKooten asked Mar 18 '23 13:03



2 Answers

A possible approach follows. Use it with caution: it is hardly tested, although the results for n=10 and t=5 look reasonable.

The approach involves no recursion. The algorithm that generates the partitions of a number n into exactly m parts comes from Knuth's Volume 4; it is run for every total up to your n (10) and every part count up to your t (5). Each partition is then zero-extended if necessary, and all of its distinct permutations are generated using an algorithm by Aaron Williams that I have seen referred to elsewhere. Both algorithms had to be translated to Python, which increases the chance that errors have crept in. The Williams algorithm requires a linked list, which I faked with a 2D array (each entry is a [value, next index] pair) to avoid writing a linked-list class.

There goes an afternoon!

Code (note your n is my maxn and your t is my p):

import itertools

def visit(a, m):
    """ Utility function to add a partition to the global list x (defined below)"""
    x.append(a[1:m+1])

def parts(a, n, m):
    """ Knuth Algorithm H, Combinatorial Algorithms, Pre-Fascicle 3B
        Finds all partitions of n having exactly m elements.
        An upper bound on running time is (3 x number of
        partitions found) + m.  Not recursive!      
    """
    while True:
        visit(a, m)
        while a[2] < a[1]-1:
            a[1] -= 1
            a[2] += 1
            visit(a, m)
        j=3
        s = a[1]+a[2]-1
        while a[j] >= a[1]-1:
            s += a[j]
            j += 1
        if j > m:
            break
        x = a[j] + 1
        a[j] = x
        j -= 1
        while j > 1:
            a[j] = x
            s -= x
            j -= 1
        a[1] = s  # Knuth step H6: set a[1] once, after the loop

def distinct_perms(partition):
    """ Aaron Williams Algorithm 1, "Loopless Generation of Multiset
        Permutations by Prefix Shifts".  Finds all distinct permutations
        of a list with repeated items.  I don't follow the paper all that
        well, but it _possibly_ has a running time which is proportional
        to the number of permutations (with 3 shift operations for each  
        permutation on average).  Not recursive!
    """

    perms = []
    val = 0
    nxt = 1
    l1 = [[v, i + 1] for i, v in enumerate(partition)]  # [value, next index] nodes
    l1[-1][nxt] = None  # the last node ends the list
    head = 0
    i = len(l1)-2
    afteri = i+1
    tmp = []
    tmp += [l1[head][val]]
    c = head
    while l1[c][nxt] is not None:
        tmp += [l1[l1[c][nxt]][val]]
        c = l1[c][nxt]
    perms.extend([tmp])
    while (l1[afteri][nxt] is not None) or (l1[afteri][val] < l1[head][val]):
        if (l1[afteri][nxt] is not None) and (l1[i][val] >= l1[l1[afteri][nxt]][val]):
            beforek = afteri
        else:
            beforek = i
        k = l1[beforek][nxt]
        l1[beforek][nxt] = l1[k][nxt]
        l1[k][nxt] = head
        if l1[k][val] < l1[head][val]:
            i = k
        afteri = l1[i][nxt]
        head = k
        tmp = []
        tmp += [l1[head][val]]
        c = head
        while l1[c][nxt] is not None:
            tmp += [l1[l1[c][nxt]][val]]
            c = l1[c][nxt]
        perms.extend([tmp])

    return perms

maxn = 10 # max integer to find partitions of
p = 5  # max number of items in each partition

# Find all partitions of length p or less adding up
# to maxn or less

# Special cases (Knuth's algorithm requires n and m >= 2)
x = [[i] for i in range(maxn+1)]
# Main cases: runs parts() at most (maxn^2 - maxn)/2 times (fewer when p < maxn)
for i in range(2, maxn+1):
    for j in range(2, min(p+1, i+1)):
        m = j
        n = i
        a = [0, n-m+1] + [1] * (m-1) + [-1] + [0] * (n-m-1)
        parts(a, n, m)
y = []
# For each partition, add zeros if necessary and then find
# distinct permutations.  Runs distinct_perms function once
# for each partition.
for part in x:
    # Zero-pad to length p ([0] * 0 is empty, so full-length partitions are unchanged)
    y += distinct_perms(part + [0] * (p - len(part)))
print(y)
print(len(y))
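To cross-check the output, the same pipeline (partitions of each total up to maxn into at most p parts, zero padding, distinct permutations) can be sketched more simply. This is a minimal sketch that swaps in a plain recursive partition generator and the classic lexicographic next-permutation step in place of Knuth's Algorithm H and Williams' algorithm, so it gives up the non-recursive, loopless properties for readability. The names `partitions_at_most`, `multiset_perms` and `allocations` are hypothetical.

```python
def partitions_at_most(total, max_parts, max_part=None):
    """Yield partitions of `total` into at most `max_parts` positive parts,
    each as a non-increasing list (recursive, unlike Knuth's Algorithm H)."""
    if max_part is None:
        max_part = total
    if total == 0:
        yield []  # the empty partition
        return
    if max_parts == 0:
        return  # something left to place but no parts remain
    for first in range(min(total, max_part), 0, -1):
        # later parts may not exceed `first`, keeping the list non-increasing
        for rest in partitions_at_most(total - first, max_parts - 1, first):
            yield [first] + rest


def multiset_perms(seq):
    """Yield each distinct permutation of `seq` exactly once, in lexicographic
    order, using the standard next-permutation step (not Williams' algorithm)."""
    a = sorted(seq)
    while True:
        yield list(a)
        # Find the rightmost i with a[i] < a[i + 1]; if none, we are done.
        i = len(a) - 2
        while i >= 0 and a[i] >= a[i + 1]:
            i -= 1
        if i < 0:
            return
        # Swap a[i] with the rightmost element larger than it...
        j = len(a) - 1
        while a[j] <= a[i]:
            j -= 1
        a[i], a[j] = a[j], a[i]
        # ...then reverse the non-increasing suffix to make it the smallest.
        a[i + 1:] = a[i + 1:][::-1]


def allocations(maxn, p):
    """All p-element lists of non-negative ints whose sum is <= maxn."""
    result = []
    for total in range(maxn + 1):
        for part in partitions_at_most(total, p):
            result.extend(multiset_perms(part + [0] * (p - len(part))))
    return result


print(len(allocations(10, 5)))  # 3003
```

Each valid tuple has exactly one sorted, zero-stripped form, so every tuple appears once; the total should match `len(y)` above.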
Tony answered Apr 29 '23 21:04



Write your own recursive function that does not recurse with an element unless a sum <= n is still possible.

def f(r, n, t, acc=()):
    if t == 0:
        if n >= 0:
            yield list(acc)
        return
    for x in r:  # r must be sorted ascending for the `break` below to be safe
        if x > n:  # <---- do not recurse if sum is larger than `n`
            break
        yield from f(r, n - x, t - 1, acc + (x,))

t = 5
n = 10
for xs in f(range(n + 1), n, t):
    print(xs)
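A non-recursive alternative that never generates a rejected candidate is the stars-and-bars bijection, a different technique from both answers here: each valid t-tuple corresponds to one way of placing t bars among n + t slots, where the star counts before each bar are the t values and any stars after the last bar are the unallocated remainder. A minimal sketch using `itertools.combinations`; the name `bounded_compositions` is hypothetical:

```python
from itertools import combinations


def bounded_compositions(n, t):
    """Yield every t-tuple of non-negative ints whose sum is <= n.

    Each choice of t bar positions out of n + t slots maps to exactly one
    tuple: value k is the number of stars between bar k-1 and bar k.
    Stars after the last bar are the unallocated remainder.
    """
    for bars in combinations(range(n + t), t):
        prev = -1  # virtual bar just before slot 0
        values = []
        for b in bars:
            values.append(b - prev - 1)  # stars strictly between two bars
            prev = b
        yield tuple(values)


t = 5
n = 10
results = list(bounded_compositions(n, t))
print(len(results))  # 3003
```

Since `combinations` produces exactly C(n+t, t) bar placements, the work is proportional to the number of valid tuples rather than to (n+1)^t.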
falsetru answered Apr 29 '23 21:04
