 

Haskell Performance Optimization

I am writing code to find the nth Ramanujan-Hardy number. A Ramanujan-Hardy number is defined as

n = a^3 + b^3 = c^3 + d^3

i.e. n can be expressed as a sum of two cubes in two different ways.
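As a worked instance of the definition, 1729, the smallest such number, has two such representations:

```latex
1729 = 1^3 + 12^3 = 9^3 + 10^3 \qquad (1 + 1728 = 729 + 1000)
```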

I wrote the following code in Haskell:

-- My own implementation of the integer cube root. Expected time complexity is O(n^(1/3)).
cube_root n = chelper 1 n
                where
                        chelper i n = if i*i*i > n then (i-1) else chelper (i+1) n

-- Checks whether the given number can be expressed as a^3 + b^3 = c^3 + d^3 (i.e. is it a Ramanujan-Hardy number?)
is_ram n = length [a| a<-[1..crn], b<-[(a+1)..crn], c<-[(a+1)..crn], d<-[(c+1)..crn], a*a*a + b*b*b == n && c*c*c + d*d*d == n] /= 0
        where
                crn = cube_root n

-- Finds the nth Ramanujan number by iterating x from 1 until the nth one is found.
-- In the recursion, if x is a Ramanujan number, n is decremented; either way, x is incremented.
-- When n reaches 0, the preceding x was the desired Ramanujan number.
ram n = give_ram 1 n
        where
                give_ram x 0 = (x-1)
                give_ram x n = if is_ram x then give_ram (x+1) (n-1) else give_ram (x+1) n
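One common speed-up, sketched here under the assumption that inputs are positive Integers: replace the linear cube_root with a binary search, and instead of four nested loops, try each a with 2a^3 < n and check whether n - a^3 is a perfect cube. The names icbrt, cubeSumReps and isRamFast are hypothetical and this is not the asker's code. Each check then costs roughly O(n^(1/3) log n) steps rather than O(n^(4/3)).

```haskell
-- Sketch of a faster membership test. icbrt, cubeSumReps and isRamFast
-- are hypothetical names; inputs are assumed to be positive Integers.

-- Integer cube root by binary search: the largest r with r^3 <= n.
-- Takes O(log n) steps instead of the O(n^(1/3)) linear scan.
icbrt :: Integer -> Integer
icbrt n
  | n < 1     = 0
  | otherwise = go 0 (n + 1)          -- invariant: lo^3 <= n < hi^3
  where
    go lo hi
      | hi - lo <= 1 = lo
      | mid ^ 3 <= n = go mid hi
      | otherwise    = go lo mid
      where
        mid = (lo + hi) `div` 2

-- Every pair (a, b) with 1 <= a < b and a^3 + b^3 == n.
-- For a fixed a, b is forced to be the cube root of n - a^3,
-- so one binary search per a replaces the loops over b, c and d.
cubeSumReps :: Integer -> [(Integer, Integer)]
cubeSumReps n =
  [ (a, b)
  | a <- takeWhile (\x -> 2 * x ^ 3 < n) [1 ..]  -- given a^3 + b^3 == n, a < b exactly when 2a^3 < n
  , let b = icbrt (n - a ^ 3)
  , a ^ 3 + b ^ 3 == n
  ]

-- At least two distinct representations; stops after finding two.
isRamFast :: Integer -> Bool
isRamFast n = length (take 2 (cubeSumReps n)) == 2
```

Loaded into GHCi, `filter isRamFast [1 .. 5000]` should give [1729,4104]. Testing every x one by one is still wasteful, though, which is what the approach in the answer avoids.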

In my estimate, the time complexity of checking whether a number is a Ramanujan number is O(n^(4/3)).
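A rough cost model supports that estimate: each of the four generators a, b, c, d in is_ram ranges over at most n^(1/3) values, and because is_ram uses length, every candidate tuple is visited even after a match is found. Since ram calls is_ram on every x from 1 up to the answer N, the total work grows like N^(7/3), up to constant factors (the O(n^(1/3)) cube_root call is dominated):

```latex
\begin{aligned}
T_{\mathrm{is\_ram}}(n) &= O\!\left(\left(n^{1/3}\right)^{4}\right) = O\!\left(n^{4/3}\right),\\
T_{\mathrm{ram}}(N) &= \sum_{x=1}^{N} O\!\left(x^{4/3}\right) = O\!\left(N^{7/3}\right).
\end{aligned}
```

For the 2nd Ramanujan number, N = 4104 and N^(1/3) is just over 16, so N^(7/3) is about 16^7 = 2^28, roughly 2.7 x 10^8, again ignoring constant factors.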

When I run this code in GHCi, it takes a long time even to find the 2nd Ramanujan number.

What are possible ways to optimize this code?

asked Mar 15 '23 23:03 by doptimusprime



1 Answer

First, a small clarification of what we're looking for. A Ramanujan-Hardy number is one which can be written in two different ways as a sum of two cubes, i.e. a^3 + b^3 = c^3 + d^3 where a < b and a < c < d.

An obvious idea is to generate all of the cube-sums in sorted order and then look for adjacent sums which are the same.

Here's a start, a function which generates all of the cube sums with a given first cube:

cubes a = [ (a^3+b^3, a, b) | b <- [a+1..] ]
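For reference, a typed version of this first cubes that loads into GHCi on its own; the signature fixing the element type to Integer is an assumption, since the answer leaves it to type inference.

```haskell
-- All cube sums a^3 + b^3 with b > a, for one fixed a.
-- The list is infinite and strictly increasing in its first component.
cubes :: Integer -> [(Integer, Integer, Integer)]
cubes a = [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 ..] ]
```

Thanks to laziness, `take 3 (cubes 1)` gives [(9,1,2),(28,1,3),(65,1,4)] even though the list never ends.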

All of the possible cube sums, in order, are then just:

allcubes = sort $ concat [ cubes 1, cubes 2, cubes 3, ... ]

but of course this won't work: cubes 1 is itself infinite, so concat never gets past it, and sort has to see its entire input before it can return anything. However, since cubes a is an increasing sequence, we can sort all of the sequences together by merging them:
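To see that the sort/concat idea is sound once the lists are finite, here is a sketch that caps b at a hypothetical `limit` parameter (the name cubeSumsUpTo is also hypothetical). Every sum below (limit + 1)^3 is guaranteed to be present, since such a sum needs b <= limit, but the cap has to be chosen in advance, which is exactly what the lazy merge avoids.

```haskell
import Data.List (sort)

-- Hypothetical finite variant of the sort/concat idea: with b capped at
-- `limit`, every inner list is finite, so concat and sort can finish.
-- The result contains every cube sum below (limit + 1)^3; larger sums
-- may be missing, so the cap must be chosen up front.
cubeSumsUpTo :: Integer -> [(Integer, Integer, Integer)]
cubeSumsUpTo limit =
  sort (concat [ [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 .. limit] ] | a <- [1 .. limit] ])
```

For example, `take 5 (cubeSumsUpTo 10)` gives [(9,1,2),(28,1,3),(35,2,3),(65,1,4),(72,2,4)].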

allcubes = cubes 1 `merge` cubes 2 `merge` cubes 3 `merge` ...

Here we are taking advantage of Haskell's lazy evaluation. The definition of merge is just:

 merge [] bs = bs
 merge as [] = as
 merge as@(a:at) bs@(b:bt)
  = case compare a b of
      LT -> a : merge at bs
      EQ -> a : b : merge at bt
      GT -> b : merge as bt
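To see the laziness at work, the sketch below repeats merge (with a type signature added and the variables renamed to xs/ys) and merges two infinite increasing lists, the squares and the cubes; `squaresAndCubes` is a hypothetical name used only for this demo.

```haskell
-- merge from the answer, with a type signature and xs/ys variable names.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> x : merge xt ys
  EQ -> x : y : merge xt yt     -- keep both copies of an equal element
  GT -> y : merge xs yt

-- Hypothetical demo: merging two infinite increasing lists. Laziness means
-- only as much of each list is computed as the caller demands.
squaresAndCubes :: [Integer]
squaresAndCubes = merge [ n ^ 2 | n <- [1 ..] ] [ n ^ 3 | n <- [1 ..] ]
```

Here `take 10 squaresAndCubes` evaluates to [1,1,4,8,9,16,25,27,36,49]. The EQ case keeps both copies of 1 (and later of 64), which is what lets equal cube sums end up next to each other.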

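We still have a problem: we don't know where to stop, and an infinite chain of merges never produces its first element, because each merge must inspect the head of the next one before it can emit anything. We can solve that by having cubes a initiate cubes (a+1) at the appropriate time, i.e.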

cubes a = ...an initial part... ++ (...the rest... `merge` cubes (a+1) )

This can be written using span:

 cubes a = first ++ (rest `merge` cubes (a+1))
   where
     s = (a+1)^3 + (a+2)^3
     (first, rest) = span (\(x,_,_) -> x < s) [ (a^3+b^3,a,b) | b <- [a+1..]]

So now cubes 1 is the infinite series of all the possible sums a^3 + b^3 where a < b in sorted order.
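A typed, commented version of the span-based cubes (with merge repeated so the snippet loads on its own) makes it easy to check that claim in GHCi; the type signatures and the xs/ys renaming in merge are additions, not part of the answer.

```haskell
-- merge from the answer, with a type signature added.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> x : merge xt ys
  EQ -> x : y : merge xt yt
  GT -> y : merge xs yt

-- Sums a'^3 + b^3 with a <= a' < b, in increasing order.
cubes :: Integer -> [(Integer, Integer, Integer)]
cubes a = first ++ (rest `merge` cubes (a + 1))
  where
    -- s is the smallest sum cubes (a+1) can produce: (a+1)^3 + (a+2)^3.
    s = (a + 1) ^ 3 + (a + 2) ^ 3
    -- first: the sums below s (finite, since the list increases);
    -- rest: everything after, merged with cubes (a+1).
    (first, rest) = span (\(x, _, _) -> x < s) [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 ..] ]
```

In GHCi, `take 5 (cubes 1)` should return [(9,1,2),(28,1,3),(35,2,3),(65,1,4),(72,2,4)]: the sum 35 = 2^3 + 3^3 from cubes 2 is merged in between 28 and 65 from cubes 1.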

To find the Ramanujan-Hardy numbers, we just group together adjacent elements of the list that have the same first component:

 import Data.List (groupBy)

 sameSum (x,_,_) (y,_,_) = x == y
 rjgroups = groupBy sameSum $ cubes 1
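groupBy only compares neighbouring elements, which is why the sorted order of cubes 1 matters. A small illustration on a hand-made sample (the `sample` and `grouped` names are hypothetical, and the sample is not a real prefix of cubes 1):

```haskell
import Data.List (groupBy)

-- Two entries belong to the same group when their sums are equal.
sameSum :: (Integer, Integer, Integer) -> (Integer, Integer, Integer) -> Bool
sameSum (x, _, _) (y, _, _) = x == y

-- Hypothetical, already sorted sample for illustration only.
sample :: [(Integer, Integer, Integer)]
sample = [(9, 1, 2), (1729, 1, 12), (1729, 9, 10), (4104, 2, 16), (4104, 9, 15)]

-- Adjacent entries with equal sums end up in the same group.
grouped :: [[(Integer, Integer, Integer)]]
grouped = groupBy sameSum sample
```

Here `grouped` is [[(9,1,2)],[(1729,1,12),(1729,9,10)],[(4104,2,16),(4104,9,15)]], whereas the unsorted list [(1729,1,12),(9,1,2),(1729,9,10)] would split into three separate groups.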

The groups we are interested in are those whose length is > 1:

 rjnumbers = filter (\g -> length g > 1) rjgroups

The first 10 solutions are:

ghci> mapM_ print (take 10 rjnumbers)

[(1729,1,12),(1729,9,10)]
[(4104,2,16),(4104,9,15)]
[(13832,2,24),(13832,18,20)]
[(20683,10,27),(20683,19,24)]
[(32832,4,32),(32832,18,30)]
[(39312,2,34),(39312,15,33)]
[(40033,9,34),(40033,16,33)]
[(46683,3,36),(46683,27,30)]
[(64232,17,39),(64232,26,36)]
[(65728,12,40),(65728,31,33)]
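For convenience, here is a sketch that collects the answer's pieces into one file that can be loaded into GHCi. It adds type signatures, a hypothetical CubeSum type synonym, the Data.List import that groupBy needs, and renames merge's variables to xs/ys; otherwise the definitions follow the answer.

```haskell
import Data.List (groupBy)

-- Hypothetical synonym: (a^3 + b^3, a, b).
type CubeSum = (Integer, Integer, Integer)

-- Lazily merge two increasing lists, keeping duplicates.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> x : merge xt ys
  EQ -> x : y : merge xt yt
  GT -> y : merge xs yt

-- All sums a'^3 + b^3 with a <= a' < b, in increasing order.
-- cubes a starts cubes (a+1) only once its own sums reach the
-- smallest sum that cubes (a+1) can produce.
cubes :: Integer -> [CubeSum]
cubes a = first ++ (rest `merge` cubes (a + 1))
  where
    s = (a + 1) ^ 3 + (a + 2) ^ 3
    (first, rest) = span (\(x, _, _) -> x < s) [ (a ^ 3 + b ^ 3, a, b) | b <- [a + 1 ..] ]

-- Two entries match when their sums are equal.
sameSum :: CubeSum -> CubeSum -> Bool
sameSum (x, _, _) (y, _, _) = x == y

-- Runs of equal sums in the sorted stream.
rjgroups :: [[CubeSum]]
rjgroups = groupBy sameSum (cubes 1)

-- Sums with at least two representations.
rjnumbers :: [[CubeSum]]
rjnumbers = filter (\g -> length g > 1) rjgroups
```

After `:load`, `mapM_ print (take 10 rjnumbers)` should print the ten groups listed above.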
answered Mar 22 '23 00:03 by ErikR
