How to speed up Poisson pmf function?

My use case is to evaluate the Poisson pmf at all integer points below some small bound, say 10, and I call this function many times with different lambdas. The lambdas are not known ahead of time, so I cannot vectorize over lambda.

I heard somewhere about a "secret" trick: calling _pmf directly. What is the downside of doing so? Even then it is still somewhat slow. Is there any way to speed it up without rewriting the pmf from scratch in C?

import numpy as np
import scipy.stats

%timeit scipy.stats.poisson.pmf(np.arange(0,10),3.3)
%timeit scipy.stats.poisson._pmf(np.arange(0,10),3.3)
a = np.arange(0,10)
%timeit scipy.stats.poisson._pmf(a,3.3)

10000 loops, best of 3: 94.5 µs per loop
100000 loops, best of 3: 15.2 µs per loop
100000 loops, best of 3: 13.7 µs per loop

Update

OK, I was simply too lazy to write it in Cython. I had expected a faster solution to exist for any discrete distribution whose pmf can be evaluated sequentially (iteratively) over consecutive x. E.g. if X ~ Pois(lambda), then P(X=3) = P(X=2) * lambda / 3, and in general P(X=k) = P(X=k-1) * lambda / k.
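A minimal sketch of that recurrence in plain NumPy, assuming the points of interest are exactly 0, 1, ..., n-1 and lambda is a non-negative scalar. The function name poisson_pmf_recursive is hypothetical, not part of SciPy. It needs a single exp() per lambda, and the running product of the ratios lambda/k yields the remaining terms.

```python
import math
import numpy as np


def poisson_pmf_recursive(n, lam):
    """Return [P(X=0), ..., P(X=n-1)] for X ~ Poisson(lam).

    Hypothetical helper: uses P(X=0) = exp(-lam) and the recurrence
    P(X=k) = P(X=k-1) * lam / k, so only one exp() call per lambda.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if lam < 0:
        raise ValueError("lam must be non-negative")
    out = np.empty(n)
    out[0] = math.exp(-lam)                # P(X=0)
    # Successive ratios P(X=k) / P(X=k-1) = lam/1, lam/2, ..., lam/(n-1)
    ratios = lam / np.arange(1, n)
    out[1:] = out[0] * np.cumprod(ratios)  # running product gives P(X=k)
    return out


print(poisson_pmf_recursive(10, 3.3))
```

Its output should agree with scipy.stats.poisson.pmf(np.arange(10), 3.3) up to floating-point rounding. Whether it actually beats _pmf for a 10-point support is something to check with %timeit. For very large lambdas exp(-lam) underflows to 0 (beyond roughly 745), so a log-space formula is safer there.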

Related: Are the built-in probability density functions of `scipy.stats.distributions` slower than a user-provided one?

I have less faith in SciPy and Python now. The library function isn't as advanced as I had expected.

asked Feb 19 '26 05:02 by colinfang


1 Answer

Most scipy.stats distributions support vectorized evaluation, including over the shape parameter (here, mu):

>>> from scipy.stats import poisson
>>> poisson.pmf(1, [5, 6, 7, 8])
array([ 0.03368973,  0.01487251,  0.00638317,  0.0026837 ])

This may or may not be fast enough, but you can try moving the pmf calls out of the loop and evaluating all lambdas in a single call.
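A sketch of what taking the calls out of the loop could look like, assuming the lambdas can be collected first (e.g. gathered during the loop and evaluated afterwards), which the question suggests is not always possible. The values in lams are illustrative only. Broadcasting the points as a column against the lambdas as a row makes one pmf call return the whole table.

```python
import numpy as np
from scipy.stats import poisson

k = np.arange(10)                 # support points 0..9
lams = np.array([3.3, 4.1, 5.0])  # illustrative batch of lambdas

# k as a (10, 1) column broadcast against lams as a (1, 3) row gives a
# (10, 3) table: column j holds P(X = k) for X ~ Poisson(lams[j]).
table = poisson.pmf(k[:, np.newaxis], lams[np.newaxis, :])
print(table.shape)
```

The argument checking in the public pmf then runs once per batch rather than once per lambda.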

As for the difference between pmf and _pmf: the real work is done in the underscored functions (_pmf, _cdf, etc.), while the public functions (pmf, cdf) make sure that only valid arguments reach _pmf. The output of _pmf is not guaranteed to be meaningful if the arguments are invalid, so use it at your own risk.

>>> poisson.pmf(1, -1)
nan
>>> poisson._pmf(1, -1)
/home/br/virtualenvs/scipy-dev/local/lib/python2.7/site-packages/scipy/stats/_discrete_distns.py:432: RuntimeWarning: invalid value encountered in log
  Pk = k*log(mu)-gamln(k+1) - mu
nan
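The traceback line suggests that, in that SciPy version, the Poisson pmf is computed in log space as exp(k*log(mu) - gammaln(k+1) - mu). A minimal sketch of calling that formula directly with a single, simplified validation step, assuming mu is a positive scalar and k holds non-negative integers (k itself is not checked). The name poisson_pmf_logspace is hypothetical, and the check is not SciPy's own argument validation.

```python
import numpy as np
from scipy.special import gammaln


def poisson_pmf_logspace(k, mu):
    """Hypothetical helper: Poisson pmf via exp(k*log(mu) - gammaln(k+1) - mu).

    Assumes k contains non-negative integers; only mu is validated,
    once per call, instead of going through scipy.stats' generic checks.
    """
    mu = float(mu)
    if mu <= 0:
        # log(0) would give -inf and k*log(mu) would be nan at k=0
        raise ValueError("mu must be positive")
    k = np.asarray(k, dtype=float)
    return np.exp(k * np.log(mu) - gammaln(k + 1) - mu)


print(poisson_pmf_logspace(np.arange(10), 3.3))
```

Working in log space avoids overflow in mu**k and k! for larger arguments, at the cost of a log and an exp per point.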

Further details: https://github.com/scipy/scipy/blob/master/scipy/stats/_distn_infrastructure.py#L2721

answered Feb 20 '26 19:02 by ev-br


