Ordered pair for a number

For a given number N, find the total number of possible ordered pairs (x, y) such that x and y are both less than or equal to N and the sum of digits of x is less than the sum of digits of y.

For example, for N = 6 there are 21 possible ordered pairs: [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (2, 3), (2, 4), (2, 5), (2, 6), (3, 4), (3, 5), (3, 6), (4, 5), (4, 6), (5, 6)]
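
The count of 21 can be checked by hand. For N ≤ 9 every number in 0..N is a single digit, so its digit sum is the number itself and the digit-sum condition reduces to x < y. Writing s(x) for the digit sum of x (notation introduced here, not in the original):

```latex
s(x) = x \quad (0 \le x \le 9)
\;\Longrightarrow\;
\#\{(x, y) : 0 \le x < y \le N,\ s(x) < s(y)\} = \binom{N+1}{2} = \frac{(N+1)N}{2},
\qquad N = 6:\ \frac{7 \cdot 6}{2} = 21.
```

For N ≥ 10 the two conditions diverge (for example s(9) = 9 but s(10) = 1), which is why the digit sums actually have to be computed.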

Note that x must always be less than y, the sum of digits of x must be less than the sum of digits of y, and both x and y must be less than or equal to N. Here is my naive approach, but it is pretty slow: it works fine up to N = 10000 and performs badly beyond that.

from itertools import permutations

n = 100
lis = list(range(n + 1))
# keep the pairs with x < y whose digit sums satisfy sum(x) < sum(y)
y = list(i for i in permutations(lis, 2)
         if i[0] < i[1]
         and sum(list(map(int, list(str(i[0]))))) < sum(list(map(int, list(str(i[1]))))))
print(len(y))

A version using generators:

from itertools import permutations

for _ in range(int(input())):
    n = 1000
    lis = range(n + 1)
    # lazily filter the pairs instead of building a list
    y = (i for i in permutations(lis, 2)
         if i[0] < i[1]
         and sum(list(map(int, list(str(i[0]))))) < sum(list(map(int, list(str(i[1]))))))
    print(sum(1 for _ in y))

A slightly improved version:

from itertools import permutations

for _ in range(int(input())):
    n = 1000
    lis = range(n + 1)
    # same filter, with fewer intermediate lists
    y = (i for i in permutations(lis, 2)
         if i[0] < i[1] and sum(map(int, str(i[0]))) < sum(map(int, list(str(i[1])))))
    print(sum(1 for _ in y))
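
As a baseline for checking faster methods, the same brute force can be sketched with `itertools.combinations`, which yields each pair with x < y exactly once, so the `i[0] < i[1]` filter and half of the generated tuples disappear. The digit sums are also computed once per number rather than once per pair. This is a minimal sketch, not the asker's code; the names `digit_sum` and `count_pairs_brute` are illustrative, and it is still O(N²), so it only suits small N.

```python
from itertools import combinations


def digit_sum(n):
    """Sum of the decimal digits of a non-negative integer."""
    return sum(map(int, str(n)))


def count_pairs_brute(n):
    """Count pairs 0 <= x < y <= n with digit_sum(x) < digit_sum(y).

    Brute force over all O(n^2) pairs; intended only as a reference for small n.
    """
    sums = [digit_sum(i) for i in range(n + 1)]  # compute each digit sum once
    # combinations() yields each (x, y) with x < y exactly once
    return sum(1 for x, y in combinations(range(n + 1), 2) if sums[x] < sums[y])


print(count_pairs_brute(6))  # 21
```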

Is there a better approach to tackle this problem?

asked Oct 29 '22 at 23:10 by Om Sharma


1 Answer

How the code works

The improvements here are almost exclusively algorithmic. Using generators or list comprehensions might speed it up further, but you'd have to profile it to check. The algorithm works as follows:

  1. Precompute the digit sums of 0 - N.
  2. Group the numbers 0 - N by their digit sum. This gives a mapping that, for N = 30, looks like the one below. Thus if we want the numbers with digit sum > 2, we only need to count the numbers in the rows after the third (the rows for digit sums 0, 1 and 2).

0: 0
1: 1, 10
2: 2, 11, 20
3: 3, 12, 21, 30 ...

  3. Observe that the numbers within each row are in sorted order. If our number is 12, we only need to look at numbers greater than 12, and we can find where 12 falls in each row with a binary search.
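
The three steps can be sketched on a small range. This is a minimal illustration rather than the answer's code: it assumes N = 30 so that rows 1 to 3 match the table above, it fills in digit sums with the recurrence s(i) = s(i // 10) + i % 10 (a swapped-in alternative to calling a digit-sum function per number, valid because i // 10 has already been processed), and the names `rows` and `count_greater_for` are illustrative.

```python
import bisect

N = 30  # small bound chosen so the rows match the table above

# Step 1: digit sums of 0..N; for i >= 1, i // 10 < i, so its sum is already known
digit_sums = [0] * (N + 1)
for i in range(1, N + 1):
    digit_sums[i] = digit_sums[i // 10] + i % 10

# Step 2: group numbers by digit sum; appending in increasing i keeps each row sorted
rows = {}
for i, s in enumerate(digit_sums):
    rows.setdefault(s, []).append(i)

print(rows[1], rows[2], rows[3])  # [1, 10] [2, 11, 20] [3, 12, 21, 30]


# Step 3: for a given x, count the y > x whose digit sum exceeds that of x,
# locating x's position in each higher row with a binary search
def count_greater_for(x):
    total = 0
    for s in range(digit_sums[x] + 1, max(rows) + 1):
        row = rows[s]
        total += len(row) - bisect.bisect_right(row, x)
    return total


print(count_greater_for(12))  # 15
```

Because x itself never appears in a row with a larger digit sum, `bisect_right` and the answer's `bisect_left` return the same position here.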

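Overall, this is a ~20x improvement over your algorithm (about 0.12 s vs. 2.45 s for N = 1000). The digit-sum tables cost O(N) memory; note, though, that the code below also stores every pair in `result`, which costs memory proportional to the answer itself.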

The code

import time
import bisect
import itertools

N = 6

def sum_digits(n):
    # adapted from: https://stackoverflow.com/questions/14939953/sum-the-digits-of-a-number-python
    # there may be a faster way of doing this, given that it is computed for every number in 0 .. N
    r = 0
    while n:
        r, n = r + n % 10, n // 10
    return r

t = time.time()
# trick 1: precompute all of the digit sums. This cuts the time to ~0.3s on N = 1000
digit_sums = [sum_digits(i) for i in range(N+1)]
digit_sum_map = {}

# trick 2: group the numbers by the digit sum, so we can iterate over all the numbers with a given digit sum very quickly
for i, key in enumerate(digit_sums):
    digit_sum_map.setdefault(key, []).append(i)
max_digit_sum = max(digit_sum_map.keys())

# trick 3: we insert elements into digit_sum_map in increasing order, so each list is sorted and we can
# binary search within it to find where to start counting from.
result = []
for i in range(N):
    for ds in range(digit_sums[i] + 1, max_digit_sum + 1):
        result.extend(zip(itertools.repeat(i), digit_sum_map[ds][bisect.bisect_left(digit_sum_map[ds], i):]))

print('took {} s, answer is {} for N = {}'.format(time.time() - t, len(result), N))
# took 0.0 s, answer is 21 for N = 6
# took 0.11658287048339844 s, answer is 348658 for N = 1000
# took 8.137377977371216 s, answer is 33289081 for N = 10000

# for reference, your last method takes 2.45 s on N = 1000 on my machine
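
Since the question only asks for the number of pairs, the `result.extend(...)` step can be dropped: for each x, the count contributed by a row is the row length minus the binary-search position of x. The following is a minimal sketch of that count-only variant under that assumption; the function name `count_pairs` is hypothetical, and the digit sums use the recurrence s(i) = s(i // 10) + i % 10 instead of the answer's `sum_digits`. It never creates the pair tuples, so only the O(N) lookup tables are kept in memory.

```python
import bisect


def count_pairs(n):
    """Count pairs 0 <= x < y <= n whose digit sums satisfy s(x) < s(y)."""
    # digit sums of 0..n via s(i) = s(i // 10) + i % 10
    sums = [0] * (n + 1)
    for i in range(1, n + 1):
        sums[i] = sums[i // 10] + i % 10

    # rows[s] lists, in increasing order, every number in 0..n with digit sum s
    max_sum = max(sums)
    rows = [[] for _ in range(max_sum + 1)]
    for i, s in enumerate(sums):
        rows[s].append(i)

    total = 0
    for x in range(n):  # x = n has no y > x
        for s in range(sums[x] + 1, max_sum + 1):
            row = rows[s]
            # numbers in this row that are greater than x
            total += len(row) - bisect.bisect_right(row, x)
    return total


print(count_pairs(6))  # 21
```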
answered Nov 15 '22 at 05:11 by c2huc2hu