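I recently came across the following problem: we are given a sequence x_1, ..., x_n of n integers (n < 10^5, each x_i < 2^60) and an integer S (S < 2^60). Find the smallest integer a such that the following holds:
$$\sum_{i=1}^{n} (x_i \oplus a) = S$$
Here \oplus is bitwise XOR, written ^ in the example and code that follow (whereas in the bounds 2^60 and 10^5 the ^ means exponentiation).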
For example:
x = [1, 2, 5, 10, 50, 100]
S = 242
Possible solutions for a are 21, 23, 37, 39, but the smallest is 21.
(1^21) + (2^21) + (5^21) + (10^21) + (50^21) + (100^21) = 20 + 23 + 16 + 31 + 39 + 113 = 242
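As a sanity check on the example, a brute-force search over a small window of candidates is enough. This is only a sketch: the helper name `xor_sum_value` and the window of 256 candidates are choices made here for illustration, and brute force is obviously not viable for the stated limits.

```python
def xor_sum_value(xs, a):
    """Sum of x ^ a over all x in xs (^ is bitwise XOR)."""
    return sum(x ^ a for x in xs)


xs = [1, 2, 5, 10, 50, 100]
S = 242

# Try every candidate in a small, arbitrarily chosen window.
solutions = [a for a in range(256) if xor_sum_value(xs, a) == S]
print(solutions)       # [21, 23, 37, 39]
print(min(solutions))  # 21
```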
One can build the result up bit by bit from the bottom. Starting with the lowest bit, try 0 and 1 as the lowest bit of a, and see whether the lowest bit of the resulting sum of XORs matches the corresponding bit of S. Then try the next lowest bit, propagating any carry from the previous step.
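To make a single step concrete, here is a minimal sketch of the check for one bit position. The function name `bit_choices` and its parameters are hypothetical, introduced only for this illustration; the column counts in the calls are taken from the question's example (x = [1, 2, 5, 10, 50, 100], S = 242).

```python
def bit_choices(ones, n, carry, s_bit):
    """Return the (a_bit, next_carry) pairs consistent with one bit of S.

    ones  -- how many of the n inputs have a 1 at this bit position
    carry -- carry arriving from the lower bit positions
    s_bit -- the bit of S at this position
    """
    choices = []
    for a_bit in (0, 1):
        # XOR with a 1 bit flips every input bit, so the column then holds n - ones ones.
        column = (n - ones if a_bit else ones) + carry
        if column & 1 == s_bit:
            choices.append((a_bit, column >> 1))
    return choices


# Bit 0: two inputs (1 and 5) have it set, no carry yet, S has a 0 there.
print(bit_choices(ones=2, n=6, carry=0, s_bit=0))  # [(0, 1), (1, 2)]
# Bit 1 with carry 1 (a's bit 0 was 0): no value of a's bit fits.
print(bit_choices(ones=3, n=6, carry=1, s_bit=1))  # []
# Bit 1 with carry 2 (a's bit 0 was 1): both values fit.
print(bit_choices(ones=3, n=6, carry=2, s_bit=1))  # [(0, 2), (1, 2)]
```

In this example the lowest bit of a is forced to 1 only once the next bit is examined, which is why the search has to keep both branches alive for a while.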
Following this algorithm, there may be 0, 1 or 2 choices for each bit of a, so in the worst case we may need to explore different branches and pick the one that gives the smallest result. To avoid exponential behavior, we cache previously seen results for each (bit position, carry) pair. For an input list of length n, each column sums to at most n plus the incoming carry, and the outgoing carry is half of that, so the carry always stays below n. That leaves at most kn distinct pairs, where k is the maximum number of bits in the result, and hence a worst-case complexity of O(kn).
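The bound on the number of (bit position, carry) states can also be seen in a table-filling variant of the same idea. The sketch below is not the memoized code from this answer: the name `smallest_a_iterative`, the `max_shift=80` default and the choice to return `None` when nothing fits are all assumptions made for illustration. It works from the top bit down and fills in every carry below n, so it always does the full k * n work, whereas the memoized recursion only visits carries that are actually reachable.

```python
def smallest_a_iterative(xs, S, max_shift=80):
    """Smallest a with sum(x ^ a for x in xs) == S, or None if there is none."""
    n = len(xs)
    # ones[sh] = number of inputs with bit sh set.
    ones = [sum((x >> sh) & 1 for x in xs) for sh in range(max_shift)]

    # best[c] = smallest value of a's bits at positions >= shift, given that
    # carry c arrives at position shift. Above the top bit only carry 0 is valid.
    best = {0: 0}
    for shift in reversed(range(max_shift)):
        s_bit = (S >> shift) & 1
        new_best = {}
        for carry in range(max(n, 1)):  # the carry never reaches n
            for a_bit in (0, 1):
                column = (n - ones[shift] if a_bit else ones[shift]) + carry
                if column & 1 == s_bit and (column >> 1) in best:
                    cand = a_bit + 2 * best[column >> 1]
                    if carry not in new_best or cand < new_best[carry]:
                        new_best[carry] = cand
        best = new_best
    # The addition starts at bit 0 with no carry.
    return best.get(0)


print(smallest_a_iterative([1, 2, 5, 10, 50, 100], 242))  # 21
print(smallest_a_iterative([1, 2, 5, 10, 50, 100], 243))  # None
```

The second call has no solution because with six inputs the lowest column is always even, while 243 is odd.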
Here's some Python code that implements this:
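    max_shift = 80  # more bit positions than any sum within the stated limits needs

    def xor_sum0(xs, S, shift, carry, cache, sums):
        # Smallest value for the bits of a from position shift upwards,
        # given the incoming carry; 1e100 means "no solution".
        if shift >= max_shift:
            return 1e100 if carry else 0
        key = shift, carry
        if key in cache:
            return cache[key]
        best = 1e100
        for i in range(2):
            # Column total at this position when a's bit is i, plus the carry.
            ss = sums[i][shift] + carry
            if ss & 1 == (S >> shift) & 1:
                best = min(best, i + 2 * xor_sum0(xs, S, shift + 1, ss >> 1, cache, sums))
        cache[key] = best
        return cache[key]

    def xor_sum(xs, S):
        # sums[i][sh] = number of inputs whose bit sh is 1 after XOR with bit value i.
        sums = [
            [sum(((x >> sh) ^ i) & 1 for x in xs) for sh in range(max_shift)]
            for i in range(2)]
        return xor_sum0(xs, S, 0, 0, dict(), sums)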
If there is no solution, the code returns a large (>= 1e100) floating-point number.
And here's a test that picks random values in the ranges you gave, picks a random a and computes S, and then solves. Note that sometimes the code finds a smaller a than the one that was used to compute S, since values of a are not always unique.
    import random

    xs = [random.randrange(0, 1 << 61) for _ in range(random.randrange(10 ** 5))]
    a_original = random.randrange(1 << 61)
    S = sum(x ^ a_original for x in xs)
    print(S)
    print(xs)
    a = xor_sum(xs, S)
    assert a < 1e100  # a solution must exist, since S was built from a_original
    print('a:', a)
    print('original a:', a_original)
    assert a <= a_original
    print('S', S)
    print('SUM', sum(x ^ a for x in xs))
    assert sum(x ^ a for x in xs) == S