Speed up calculation of partitions in Haskell

I'm trying to solve Euler problem 78, which basically asks for the first number where the partition function p(n) is divisible by 1000000.

I use Euler's recursive formula based on pentagonal numbers (computed here in pents, together with the proper sign). Here is my code:

ps = 1 : map p [1..] where
  p n = sum $ map getP $ takeWhile ((<= n).fst) pents where
    getP (pent,sign) = sign * (ps !! (n-pent)) 

pents = zip (map (\n -> (3*n-1)*n `div` 2) $ [1..] >>= (\x -> [x,-x]))
            (cycle [1,1,-1,-1])
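For reference, the recurrence that this code implements can be written out as below. This is a restatement of Euler's pentagonal number recurrence; the summation index k is notation introduced here and does not appear in the code, where the two pentagonal terms for each k come from the pairs [x,-x] and the signs from cycle [1,1,-1,-1].

```latex
p(n) = \sum_{k \ge 1} (-1)^{k+1}
       \left[ p\!\left(n - \frac{k(3k-1)}{2}\right)
            + p\!\left(n - \frac{k(3k+1)}{2}\right) \right],
\qquad p(0) = 1, \quad p(m) = 0 \ \text{for } m < 0 .
```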

While ps seems to produce the correct results, it is too slow. Is there a way to speed the calculation up, or do I need a completely different approach?

Landei asked Apr 25 '11 14:04


4 Answers

xs !! n takes time linear in n, because it walks the list from its head. You should rather try a data structure with logarithmic or constant-time access.
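As a rough illustration of the "logarithmic access" option, here is a minimal sketch that keeps the already computed values in a Data.IntMap instead of a list. The names pentsSigned and partitionTable are made up for this example, and it assumes the containers package that ships with GHC; it is not the implementation given in the edit that follows.

```haskell
import qualified Data.IntMap.Strict as IM
import Data.List (foldl')

-- Generalized pentagonal numbers paired with their signs, as in the question.
pentsSigned :: [(Int, Integer)]
pentsSigned = zip (map (\k -> (3*k-1)*k `div` 2) $ [1..] >>= (\x -> [x,-x]))
                  (cycle [1,1,-1,-1])

-- Build a table of p(0)..p(limit). Each lookup goes through the IntMap
-- instead of walking a list with (!!).
partitionTable :: Int -> IM.IntMap Integer
partitionTable limit = foldl' step (IM.singleton 0 1) [1..limit]
  where
    step table n =
      let terms = [ sign * (table IM.! (n - pent))
                  | (pent, sign) <- takeWhile ((<= n) . fst) pentsSigned ]
      in IM.insert n (sum terms) table
```

With this, partitionTable 100 IM.! 100 gives p(100). The values are still exact Integers here, so the numbers themselves grow large as n grows.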

Edit: here is a quick implementation I came up with by copying a similar one by augustss:

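import Data.Array (listArray, (!))

-- ncache (the cache size) and pents are defined elsewhere at the top level.
psOpt x = psArr x
  where psCall 0 = 1
        psCall n = sum $ map getP $ takeWhile ((<= n).fst) pents where
          getP (pent,sign) = sign * (psArr (n-pent))
        -- Values up to ncache come from the lazily filled array.
        psArr n = if n > ncache then psCall n else psCache ! n
        psCache = listArray (0,ncache) $ map psCall [0..ncache]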

In ghci, I observe no spectacular speedup over your list version. No luck!

Edit: Indeed, with -O2 as suggested by Chris Kuklewicz, this solution is eight times faster than yours for n = 5000. Combined with hammar's insight of doing the sums modulo 10^6, I get a solution that is fast enough (it finds the hopefully correct answer in about 10 seconds on my machine):

import Data.List (find)
import Data.Array 

ps = 1 : map p [1..] where
  p n = sum $ map getP $ takeWhile ((<= n).fst) pents where
    getP (pent,sign) = sign * (ps !! (n-pent)) 

summod li = foldl (\a b -> (a + b) `mod` 10^6) 0 li

ps' = 1 : map p [1..] where
  p n = summod $ map getP $ takeWhile ((<= n).fst) pents where
    getP (pent,sign) = sign * (ps !! (n-pent)) 

ncache = 1000000

psCall 0 = 1
psCall n = summod $ map getP $ takeWhile ((<= n).fst) pents
  where getP (pent,sign) = sign * (psArr (n-pent))
psArr n = if n > ncache then psCall n else psCache ! n
psCache = listArray (0,ncache) $ map psCall [0..ncache]

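-- The pentagonal numbers are used as bounds and indices, so they must not
-- be reduced modulo 10^6; only the sums are.
pents = zip (map (\n -> (3*n-1)*n `div` 2) $ [1..] >>= (\x -> [x,-x]))
            (cycle [1,1,-1,-1])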

(I broke the psCache abstraction by moving it to the top level, so you should use psArr instead of psOpt; this ensures that different calls to psArr reuse the same memoized array. This is useful when you write find ((== 0) . ...)... well, I thought it was better not to publish the complete solution.)

Thanks to all for the additional advice.

gasche answered Oct 19 '22 07:10


Well, one observation is that since you're only interested in map (`mod` 10^6) ps, you might be able to avoid having to do calculations on huge numbers.
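A small sketch of what this could look like, assuming a helper that reduces after every addition (the names modulus and sumMod are invented for this example). Since (a + b) `mod` m equals ((a `mod` m) + (b `mod` m)) `mod` m, reducing at each step gives the same residue as reducing the full sum at the end.

```haskell
import Data.List (foldl')

-- The modulus from the problem statement.
modulus :: Int
modulus = 10^6

-- Sum a list of signed terms, reducing after every addition so that
-- intermediate values stay small. Haskell's `mod` returns a non-negative
-- result for a positive modulus, so negative terms are handled as well.
sumMod :: [Int] -> Int
sumMod = foldl' (\acc x -> (acc + x) `mod` modulus) 0
```

Using sumMod in place of sum in p keeps every stored value below 10^6, so the computation no longer needs arbitrary-precision Integer arithmetic.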

hammar answered Oct 19 '22 06:10


I haven't done that Euler problem, but usually with Euler problems there is a clever trick to speed up the calculation.

When I see this:

sum $ map getP $ takeWhile ((<=n).fst) pents

I can't help but think there must be a smarter way than to invoke sum . map getP every time you compute an element of ps.

Now that I look at it...wouldn't it be faster to perform the sum first, and then multiply, rather than performing the multiply (inside of getP) for every element and then summing?

Normally I'd take a deeper look and provide working code; but it's an Euler problem (wouldn't want to spoil it!), so I'll just stop here with these few thoughts.

Dan Burton answered Oct 19 '22 07:10


Inspired by your question, I have used your approach to solve Euler 78 with Haskell. So I can give you some performance hints.

Your cached list of pents should be good as it is.

Pick some large number maxN to bound your search for (p n).

The plan is to use an (Array Int64 Integer) to memoize the results of (p n), with lower bound 0 and upper bound maxN. This requires defining the array in terms of 'p' and 'p' in terms of the array; the two are mutually recursive:

Redefine (p n) as (pArray n), which looks up its recursive calls to 'p' in the array A.

Use your new pArray with Data.Array.IArray.listArray to create the array A.

Definitely compile with 'ghc -O2'. This runs in 13 seconds here.
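The mutual recursion described above might be sketched roughly as follows. This is only an illustration under some assumptions: it uses Int rather than Int64 as the index type, calls the array table instead of A, picks an arbitrary small maxN, and leaves out both the search and the reduction modulo 10^6, so it is not a complete solution.

```haskell
import Data.Array (Array, listArray, (!))

-- Upper bound for the table; an arbitrary value chosen for this sketch.
maxN :: Int
maxN = 1000

-- Generalized pentagonal numbers paired with their signs, as in the question.
pents :: [(Int, Integer)]
pents = zip (map (\k -> (3*k-1)*k `div` 2) $ [1..] >>= (\x -> [x,-x]))
            (cycle [1,1,-1,-1])

-- The array is defined in terms of pArray ...
table :: Array Int Integer
table = listArray (0, maxN) (map pArray [0..maxN])

-- ... and pArray looks up its recursive calls in the array. This works
-- because the elements of a boxed Array are evaluated lazily.
pArray :: Int -> Integer
pArray 0 = 1
pArray n = sum [ sign * (table ! (n - pent))
               | (pent, sign) <- takeWhile ((<= n) . fst) pents ]
```

Each element of table is computed at most once, the first time it is demanded, and every lookup takes constant time.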

Chris Kuklewicz answered Oct 19 '22 05:10