
How to compute an expensive high precision sum in python?

My problem is very simple. I would like to compute the following sum.

from scipy.special import comb  # was scipy.misc.comb in older SciPy releases
import math

for n in range(2, 1000, 10):
    m = 2.2*n/math.log(n)
    print(sum(sum(comb(n, a) * comb(n-a, b) * (comb(a+b, a)*2**(-a-b))**m
                  for b in range(n+1))
              for a in range(1, n+1)))

However, Python emits RuntimeWarning: overflow encountered in multiply, prints nan as the output, and is also very slow.

Is there a clever way to do this?

asked Feb 03 '14 at 18:02 by graffe



1 Answer

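You get NaNs because the sum involves numbers such as

comb(600 + 600, 600) == 3.96509646226102e+359

Values like this, and products such as comb(n, a) * comb(n-a, b), exceed the double-precision range and become inf. Once an inf is multiplied by zero (for example a factor that has underflowed to 0.0), the result is nan, and a single nan makes the whole sum nan. The largest finite float is: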

>>> numpy.finfo(float).max
1.7976931348623157e+308
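The same limits can be checked with the standard library alone. This is a minimal sketch assuming Python 3.8+ (for math.comb), not part of the original answer: the exact integer C(1200, 600) exists, converting it to a float raises OverflowError, and its logarithm is an ordinary-sized number.

```python
import math
import sys

# Exact value of C(1200, 600) as an arbitrary-precision Python int (Python 3.8+).
big = math.comb(1200, 600)

# Largest finite double; the same value numpy.finfo(float).max reports.
print(sys.float_info.max)

# The exact integer does not fit in a double, so the conversion raises.
try:
    float(big)
except OverflowError as exc:
    print("float(comb(1200, 600)) failed:", exc)

# Its natural logarithm is small: log C(n, k) = lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1).
log_big = math.lgamma(1201) - 2 * math.lgamma(601)
print(log_big, math.log(big))  # math.log accepts arbitrarily large ints
```

Working with log_big instead of big is exactly what keeps every intermediate value in range.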

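To avoid this, work with logarithms: compute the log of each term, combine the terms with logsumexp, and exponentiate only the final result: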

from __future__ import division, absolute_import, print_function
from scipy.special import betaln, logsumexp  # logsumexp used to live in scipy.misc
import numpy as np


def binomln(n, k):
    # log C(n, k) via the identity B(n-k+1, k+1) = 1 / ((n+1) * C(n, k));
    # betaln returns log|B|, so this assumes C(n, k) >= 0.
    return -betaln(1 + n - k, 1 + k) - np.log(n + 1)


for n in range(2, 1000, 10):
    m = 2.2*n/np.log(n)

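    # a varies along the columns and b along the rows, so NumPy broadcasting
    # evaluates every (a, b) term of the double sum in one pass.
    a = np.arange(1, n + 1)[np.newaxis, :]
    b = np.arange(n + 1)[:, np.newaxis]

    # Log of each term: log C(n,a) + log C(n-a,b) + m*(log C(a+b,a) - (a+b)*log 2)
    v = (binomln(n, a)
         + binomln(n - a, b)
         + m*binomln(a + b, a)
         - m*(a+b) * np.log(2))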

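    # logsumexp(v) = log(sum(exp(v))) evaluated without overflow;
    # only the final total is exponentiated.
    total = np.exp(logsumexp(v))
    print(total)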
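For readers without SciPy, here is a dependency-free sketch of the same log-space computation. It swaps in math.lgamma for scipy.special.betaln and a hand-written logsumexp (the standard shift-by-the-maximum trick), and loops in pure Python instead of broadcasting; log_sum and direct_sum are hypothetical helper names introduced for illustration. direct_sum evaluates the sum naively and is only safe for small n, which makes it a useful cross-check.

```python
import math


def binomln(n, k):
    """log C(n, k) from log-gamma; -inf where C(n, k) == 0 (k outside 0..n)."""
    if k < 0 or k > n:
        return -math.inf
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)


def logsumexp(values):
    """log(sum(exp(v) for v in values)), shifted by the maximum to avoid overflow."""
    top = max(values)
    if top == -math.inf:
        return -math.inf  # every term is zero
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_sum(n, m):
    """Log of sum_{a=1..n} sum_{b=0..n} C(n,a) C(n-a,b) (C(a+b,a) / 2**(a+b))**m."""
    log2 = math.log(2)
    terms = [
        binomln(n, a) + binomln(n - a, b) + m * (binomln(a + b, a) - (a + b) * log2)
        for a in range(1, n + 1)
        for b in range(n + 1)
    ]
    return logsumexp(terms)


def direct_sum(n, m):
    """Naive floating-point evaluation; only safe for small n where nothing overflows."""
    return sum(
        math.comb(n, a) * math.comb(n - a, b) * (math.comb(a + b, a) / 2 ** (a + b)) ** m
        for a in range(1, n + 1)
        for b in range(n + 1)
    )


# Compare the log-space result with the naive sum where the latter is still safe.
for n in (2, 12, 32):
    m = 2.2 * n / math.log(n)
    print(n, math.exp(log_sum(n, m)), direct_sum(n, m))
```

Because the pure-Python double loop visits about n² terms, it is much slower than the vectorised NumPy version for n near 1000; its value is in making each identity easy to verify.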
answered Sep 29 '22 at 12:09 by pv.
