I want to count the number of combinations (triples) that fulfill the following condition:
a < b < a+d < c < b+d
where a, b, c are elements of a list, and d is a fixed delta.
Here is a vanilla implementation:
def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s
Here is a test:
def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4))  # Gone through everything by hand.
How can I speed this up? I am looking at list sizes of 2 million.
I am dealing with floats in the range [-pi, pi]. For example, with d = pi this limits a < 0, because a + d < c <= pi.
I have an implementation where I build indices that I use for b and c. However, the code below fails some cases (i.e. it is wrong).
from math import pi

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
        for indB in range(indA+1, low[indA]+1):
            s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t. l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i.
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind

def upper(l, d=pi):
    '''Returns, for each element elem of l, the first index x where l[x] > elem + d.
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind
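To pin down which inputs an indexed version like this gets wrong, it can be checked against the vanilla triple loop on many small random lists. The following is only a sketch of such a harness: `brute_force` and `find_counterexample` are hypothetical helper names, and the value range, list length and trial count are arbitrary assumptions.

```python
import random


def brute_force(values, d):
    """Reference answer: the vanilla O(n^3) triple loop from the question."""
    return sum(
        1
        for a in values
        for b in values
        for c in values
        if a < b < a + d < c < b + d
    )


def find_counterexample(candidate, trials=200, seed=0):
    """Return (values, d, expected, got) for the first mismatch, or None.

    `candidate` is any function taking (sorted_list, d). Small integer
    inputs with many duplicates are used so a failing case stays readable.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        size = rng.randint(0, 8)
        values = sorted(rng.randint(0, 10) for _ in range(size))
        d = rng.randint(1, 5)
        expected = brute_force(values, d)
        try:
            got = candidate(list(values), d)
        except Exception as exc:  # a crash also counts as a failing case
            got = exc
        if got != expected:
            return values, d, expected, got
    return None
```

Calling `find_counterexample(count)` with the indexed `count` above would then return a small failing list (or `None` if no mismatch shows up within the trials).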
The original problem is from a well known math/comp-sci competition. The competition asks that you don't post solutions on the net. But it is from two weeks ago.
I can generate the list with this function:
from math import atan2, pi

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    # points() is a generator, so sort it with sorted() instead of .sort()
    angles = sorted(points(n))
    return count(angles, pi)
There is an approach to your problem that yields an O(n log n) algorithm. Let X be the set of values. Now let's fix b. Let A_b be the set of values { x in X : b - d < x < b } and C_b be the set of values { x in X : b < x < b + d }. If we can find |{ (x,y) in A_b x C_b : y > x + d }| fast, we have solved the problem.
If we sort X, we can represent A_b and C_b as pointers into the sorted array, because they are contiguous. If we process the b candidates in non-decreasing order, we can thus maintain these sets using a sliding window algorithm. It goes like this:
1. Sort X. Let X = { x_1, x_2, ..., x_n }, x_1 <= x_2 <= ... <= x_n.
2. Set left = i = 1 and set right so that C_b = { x_{i + 1}, ..., x_right }. Set count = 0.
3. Iterate i from 1 to n. In every iteration we find out the number of valid triples (a,b,c) with b = x_i. To do that, increase left and right as much as necessary so that A_b = { x_left, ..., x_{i-1} } and C_b = { x_{i + 1}, ..., x_right } still hold. In the process, you basically add and remove elements from the imaginary sets A_b and C_b.
   If you remove or add an element to one of the sets, check how many pairs (a, c) with c > a + d, a from A_b and c from C_b you add or destroy (this can be achieved by a simple binary search in the other set). Update count accordingly so that the invariant count = |{ (x,y) in A_b x C_b : y > x + d }| still holds.
4. Sum up count in every iteration. This sum is the final result.
The complexity is O(n log n).
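The fixed-b idea can also be sketched without incremental window updates. This is a prefix-sum variant, not the exact sliding-window bookkeeping described above: for a fixed b and any a in A_b, a + d > b already holds, so the valid c are exactly those with a + d < c < b + d, and summing that count over A_b only needs a prefix sum. The function name `count_triples` is hypothetical, and the sketch assumes d > 0 and ignores floating-point rounding in `x + d` and `b - d`.

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate


def count_triples(values, d):
    """Count triples with a < b < a + d < c < b + d in O(n log n)."""
    if d <= 0:
        return 0  # the chain a < b < a + d cannot hold
    xs = sorted(values)
    # not_above[j] = number of elements <= xs[j] + d, i.e. the elements
    # that are too small to serve as c when a = xs[j].
    not_above = [bisect_right(xs, x + d) for x in xs]
    prefix = [0] + list(accumulate(not_above))

    total = 0
    for b in xs:
        lo = bisect_right(xs, b - d)   # first index with xs[lo] > b - d
        hi = bisect_left(xs, b)        # first index with xs[hi] >= b
        n_a = hi - lo                  # size of A_b = xs[lo:hi]
        if n_a == 0:
            continue
        c_end = bisect_left(xs, b + d)  # number of elements < b + d
        # Sum over a in A_b of (c_end - not_above[a]).
        total += n_a * c_end - (prefix[hi] - prefix[lo])
    return total
```

On the test list from the question, `count_triples(l, 4)` gives 32, matching the hand-checked value. Because the windows move monotonically as b grows, the `bisect` calls inside the loop could be replaced by advancing pointers, which is where the sliding-window formulation comes from.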
If you want to solve the Euler problem with this algorithm, you have to avoid floating-point issues. I suggest sorting the points by angle using a custom comparison function that uses integer arithmetic only (using 2D vector geometry). The |a-b| < d comparisons can also be implemented using integer operations only. Also, since you are working modulo 2*pi, you would probably have to introduce three copies of every angle a: a - 2*pi, a and a + 2*pi. You then only look for b in the range [0, 2*pi) and divide the result by three.
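As a rough illustration of the integer-only angular sort, the sketch below orders integer vectors the same way `atan2(y, x)` would, using only sign tests and a cross product. The names `compare_by_angle` and `sort_by_angle` are hypothetical, the zero vector is assumed not to occur, and each vector is taken as `(x, y)`; in the question's generator the two atan2 arguments would have to be mapped accordingly.

```python
from functools import cmp_to_key


def _region(v):
    """Coarse angular region of v = (x, y), in increasing atan2 order."""
    x, y = v
    if y < 0:
        return 0            # angle in (-pi, 0)
    if y == 0 and x > 0:
        return 1            # angle exactly 0
    if y > 0:
        return 2            # angle in (0, pi)
    return 3                # y == 0 and x < 0: angle exactly pi


def compare_by_angle(u, v):
    """Return -1, 0 or 1 as angle(u) is <, == or > angle(v)."""
    ru, rv = _region(u), _region(v)
    if ru != rv:
        return -1 if ru < rv else 1
    # Inside one region the angles differ by less than pi, so the sign of
    # the cross product decides: positive means v is counter-clockwise of u.
    cross = u[0] * v[1] - u[1] * v[0]
    if cross > 0:
        return -1
    if cross < 0:
        return 1
    return 0


def sort_by_angle(vectors):
    """Sort integer vectors by polar angle without any floating point."""
    return sorted(vectors, key=cmp_to_key(compare_by_angle))
```

Vectors pointing in the same direction compare as equal, which is what lets duplicates be detected exactly instead of depending on the last bits of a float.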
UPDATE: The OP implemented this algorithm in Python. Apparently it contains some bugs, but it demonstrates the general idea:
from bisect import bisect_left, bisect_right

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b.
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
        # Find boundaries of C s.t. b < c < b + d
        while c_l < length and X[c_l] <= b:
            c_l += 1  # this removes an element from C_b
            ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
            if a_l <= ind <= a_r:
                count -= (ind - a_l)
        while c_r < length and X[c_r] < b + d:
            c_r += 1  # this adds an element to C_b
            ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
            if a_l <= ind <= a_r:
                count += (ind - a_l)
        s += count
    return s
from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)
    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
This version:
- gets rid of repeated, identical values
- iterates over only the required range for a value
- uses a cumulative count across two indices to eliminate the loop over c
- caches lookups on x + d
This is no longer O(n^3) but more like O(n^2).
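The cumulative-count trick on its own can be sketched as follows: after collapsing duplicates, a prefix sum over the per-value counts answers "how many elements lie strictly between lo and hi" with two binary searches. The helper name `build_range_counter` is hypothetical and only meant to isolate that one step.

```python
from bisect import bisect_left, bisect_right
from collections import Counter
from itertools import accumulate


def build_range_counter(values):
    """Return a function counting elements v with lo < v < hi."""
    counts = Counter(values)
    keys = sorted(counts)
    # cumulative[i] = number of elements among the i smallest unique keys
    cumulative = [0] + list(accumulate(counts[k] for k in keys))

    def count_between(lo, hi):
        left = bisect_right(keys, lo)   # first key > lo
        right = bisect_left(keys, hi)   # first key >= hi
        # An empty or inverted range yields 0 rather than a negative number.
        return max(0, cumulative[right] - cumulative[left])

    return count_between
```

For the test list, `build_range_counter(l)(4, 8)` counts the elements 5, 7 and 7, so it returns 3.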
This clearly does not yet scale up to 2 million. Here are my times on smaller floating-point data sets (i.e. few or no duplicates), using Cython to speed up execution:
50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds
Here is my benchmarking code (not including the operative code above):
from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()
Since l is sorted and a < b < c must be true, you could use itertools.combinations() to do fewer loops:
sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
Looking at combinations only reduces this loop to 816 iterations (down from 18^3 = 5832 for the triple loop):
>>> from itertools import combinations
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32
Here the a < b test is redundant: combinations of a sorted list already yield a <= b, and a == b would make a + d < c < b + d impossible.
1) To reduce the number of iterations on each level, you can filter out the elements of the list that don't pass the condition for that level.
2) Using set with collections.Counter, you can reduce iterations further by removing duplicates:
from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s
>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32
Tested count of iterations (a, b, c) for your version:
>>> count1(l, 4)
18 324 5832
my version:
>>> count2(l, 4)
9 16 7