
How to count number of combinations?

I want to count the number of combinations that satisfy the following condition:

 a < b < a+d < c < b+d

Where a, b, c are elements of a list, and d is a fixed delta.

Here is a vanilla implementation:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

Here is a test:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

Question

How can I speed this up? I am looking at list sizes of 2 Million.

Supplementary Information

I am dealing with floats in the range [-pi, pi]. With d = pi, for example, this forces a < 0, because a + d < c <= pi.

What I have so far:

I have an implementation that builds index arrays which I then use to bound b and c. However, the code below fails some cases, so it is not correct.

from math import pi

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
        for indB in range(indA+1, low[indA]+1):
            s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    '''Returns, for each i, the first index x where l[x] > l[i] + d'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind

Original Problem

The original problem is from a well known math/comp-sci competition. The competition asks that you don't post solutions on the net. But it is from two weeks ago.

I can generate the list with this function:

from math import atan2, pi

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    # points() is a generator, so build the sorted list with sorted()
    angles = sorted(points(n))
    return count(angles, pi)

Unapiedra asked Feb 08 '14 15:02


4 Answers

There is an approach to your problem that yields an O(n log n) algorithm. Let X be the set of values. Now let's fix b. Let A_b be the set of values { x in X : b - d < x < b } and C_b be the set of values { x in X : b < x < b + d }. If we can find |{ (x, y) in A_b × C_b : y > x + d }| fast, we have solved the problem.

If we sort X, we can represent A_b and C_b as pointers into the sorted array, because they are contiguous. If we process the b candidates in non-decreasing order, we can thus maintain these sets using a sliding window algorithm. It goes like this:

  1. Sort X. Let X = { x_1, x_2, ..., x_n }, x_1 <= x_2 <= ... <= x_n.
  2. Set left = i = 1 and set right so that C_b = { x_{i + 1}, ..., x_right }. Set count = 0.
  3. Iterate i from 1 to n. In every iteration we find the number of valid triples (a, b, c) with b = x_i. To do that, increase left and right as much as necessary so that A_b = { x_left, ..., x_{i-1} } and C_b = { x_{i + 1}, ..., x_right } still hold. In the process, you effectively add and remove elements from the imaginary sets A_b and C_b. Whenever you add or remove an element from one of the sets, check how many pairs (a, c) with c > a + d, a from A_b and c from C_b, you create or destroy (a simple binary search in the other set does this). Update count accordingly so that the invariant count = |{ (x, y) in A_b × C_b : y > x + d }| still holds.
  4. Sum up the values of count over all iterations. This is the final result.

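The complexity is O(n log n): the sort costs O(n log n), and each element enters and leaves each window at most once, with one binary search per change.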

If you want to solve the Euler problem with this algorithm, you have to avoid floating point issues. I suggest sorting the points by angle using a custom comparison function that uses integer arithmetic only (using 2D vector geometry). The |a-b| < d comparisons can also be implemented using integer operations only. Also, since you are working modulo 2*pi, you would probably have to introduce three copies of every angle a: a - 2*pi, a and a + 2*pi. You then only look for b in the range [0, 2*pi) and divide the result by three.
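To make the integer-only sorting idea concrete, here is a minimal sketch of one way it could be done. It assumes each point is a nonzero integer vector (u, v) whose angle is atan2(v, u) in (-pi, pi]; for the generator in the question that would mean v = x - 16161 and u = y - 15051. The names half, angle_cmp and sort_by_angle are made up for this illustration and do not come from the original post.

```python
from functools import cmp_to_key

def half(p):
    """Rank the direction of p = (u, v) by where atan2(v, u) falls in (-pi, pi]."""
    u, v = p
    if v < 0:
        return 0          # angle in (-pi, 0)
    if v == 0 and u > 0:
        return 1          # angle exactly 0
    if v > 0:
        return 2          # angle in (0, pi)
    return 3              # v == 0 and u < 0: angle exactly pi

def angle_cmp(p, q):
    """Three-way comparison of two nonzero integer vectors by angle, no floats."""
    hp, hq = half(p), half(q)
    if hp != hq:
        return hp - hq
    # Same class: the two angles differ by less than pi, so the sign of the
    # cross product says which vector comes first counter-clockwise.
    cross = p[0] * q[1] - p[1] * q[0]
    if cross > 0:
        return -1         # q is counter-clockwise from p, so p sorts first
    if cross < 0:
        return 1
    return 0              # same direction

def sort_by_angle(vectors):
    """Sort integer vectors by angle (assumption: no (0, 0) vector in the input)."""
    return sorted(vectors, key=cmp_to_key(angle_cmp))

print(sort_by_angle([(1, 0), (0, 1), (-1, 0), (0, -1)]))
# [(0, -1), (1, 0), (0, 1), (-1, 0)]
```

Because Python integers are arbitrary precision, the cross product cannot overflow, so vectors pointing in the same direction compare as exactly equal instead of depending on rounding in atan2.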

UPDATE: The OP implemented the sliding-window algorithm in Python. It apparently contains some bugs, but it demonstrates the general idea:

from bisect import bisect_left, bisect_right

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b. 
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
            # Find boundaries of C s.t. b < c < b + d
            while c_l < length and X[c_l] <= b:
                c_l += 1  # this removes an element from C_b
                ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count -= (ind - a_l)
            while c_r  < length and X[c_r] < b + d:
                c_r += 1 # this adds an element to C_b
                ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count += (ind - a_l)
            s += count
    return s
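
Since the code above is known to be buggy, here is a separate sketch that reaches O(n log n) by a different route: it fixes a instead of b and replaces the sliding-window bookkeeping with prefix sums. It is not the algorithm described in this answer, and the name count_prefix is made up for illustration. For a fixed a, every b with a < b < a + d contributes (number of elements < b + d) minus (number of elements <= a + d) choices for c, and the first term can be summed over a contiguous index range with a prefix-sum array.

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate

def count_prefix(values, d):
    """Count triples (a, b, c) from values with a < b < a + d < c < b + d."""
    if d <= 0:
        return 0                      # the chain of inequalities needs d > 0
    xs = sorted(values)
    # below[j] = number of elements strictly less than xs[j] + d
    below = [bisect_left(xs, x + d) for x in xs]
    # prefix[k] = below[0] + ... + below[k - 1]
    prefix = [0] + list(accumulate(below))
    total = 0
    for i, a in enumerate(xs):
        b_lo = bisect_right(xs, a)    # first index j with xs[j] > a
        b_hi = below[i]               # first index j with xs[j] >= a + d
        if b_hi <= b_lo:
            continue                  # no b strictly between a and a + d
        at_most = bisect_right(xs, a + d)   # number of elements <= a + d
        # Sum over b of (#c < b + d) minus (#c <= a + d), via the prefix sums
        total += (prefix[b_hi] - prefix[b_lo]) - (b_hi - b_lo) * at_most
    return total

l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
print(count_prefix(l, 4))  # 32, matching the hand-checked test in the question
```

This sketch assumes x + d is computed exactly, as it is for integers. With floats, rounding at the boundaries can shift individual comparisons, which is the same concern raised above about avoiding floating point.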

Niklas B. answered Oct 08 '22 20:10


from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()

Compared with the naive version, this code:

  1. gets rid of repeated, identical values
  2. iterates over only the required range for a value
  3. uses a cumulative count across two indices to eliminate the loop over c
  4. caches lookups on x + d

This is no longer O(n^3) but more like O(n^2).

This clearly does not yet scale up to 2 million. Here are my times on smaller floating point data sets (i.e. few or no duplicates) using cython to speed up execution:

50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds

Here is my benchmarking code (not including the operative code above):

from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()

15 revs answered Oct 08 '22 18:10

Since l is sorted and a < b < c must be true, you could use itertools.combinations() to do fewer loops:

from itertools import combinations

sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)

Looking at combinations only reduces this loop from 5832 (18^3) to 816 iterations on the test list, although the loop is still cubic in the list length.

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32

where the a < b test is redundant.

Martijn Pieters answered Oct 08 '22 18:10


1) To reduce the number of iterations at each level, you can skip elements that don't pass the condition at that level.
2) Using a set together with collections.Counter, you can reduce iterations further by removing duplicates:

from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32

Tested count of iterations (a, b, c) for your version:

>>> count1(l, 4)
18 324 5832

my version:

>>> count2(l, 4)
9 16 7

ndpu answered Oct 08 '22 20:10