
Getting scipy's rv_discrete to work with floating point values?

I'm trying to define my own discrete distribution. The code I have works for integer values but not for decimal values. For example, this works:

>>> from scipy.stats import rv_discrete
>>> probabilities = [0.2, 0.5, 0.3]
>>> values = [1, 2, 3]
>>> distrib = rv_discrete(values=(values, probabilities))
>>> print(distrib.rvs(size=10))
[1 3 3 2 2 2 2 2 1 3]

But if I use decimal values, it doesn't work:

>>> from scipy.stats import rv_discrete
>>> probabilities = [0.2, 0.5, 0.3]
>>> values = [.1, .2, .3]
>>> distrib = rv_discrete(values=(values, probabilities))
>>> print(distrib.rvs(size=10))
[0 0 0 0 0 0 0 0 0 0]

Thanks.

Ben S. asked Mar 12 '23 04:03


1 Answer

Per stats.rv_discrete's docstring:

values : tuple of two array_like, optional (xk, pk) where xk are integers with non-zero probabilities pk with sum(pk) = 1.

(my emphasis on "integers"). So the discrete distributions created by rv_discrete must use integer values. However, it is not hard to map those integers to floats: define the distribution over the integer indices 0, 1, 2 and use the values returned by rvs as indices into values:

In [4]: values = np.array([0.1, 0.2, 0.3])

In [5]: idx = distrib.rvs(size=10); idx
Out[5]: array([1, 1, 0, 0, 1, 1, 0, 2, 1, 1])

In [6]: values[idx]
Out[6]: array([ 0.2,  0.2,  0.1,  0.1,  0.2,  0.2,  0.1,  0.3,  0.2,  0.2])

Thus you could use:

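import numpy as np
import scipy.stats as stats

np.random.seed(2016)  # seed NumPy's global RNG so the draw is repeatable
probabilities = np.array([0.2, 0.5, 0.3])
values = np.array([0.1, 0.2, 0.3])
# Define the distribution over the integer indices 0, 1, 2 rather than the floats
distrib = stats.rv_discrete(values=(range(len(probabilities)), probabilities))
idx = distrib.rvs(size=10)  # integer indices drawn with the given probabilities
result = values[idx]  # map each index back to its float value
print(result)
# [ 0.3  0.3  0.3  0.3  0.2  0.2  0.2  0.3  0.3  0.2]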
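
If you need this in more than one place, the same index-mapping idea can be wrapped in a small function. The following is only a sketch: sample_float_values is a hypothetical helper name (it is not part of scipy), and the seed argument is an assumption added here so the draw can be reproduced without touching NumPy's global state.

```python
import numpy as np
from scipy import stats


def sample_float_values(values, probabilities, size, seed=None):
    """Draw `size` samples from `values` with the given probabilities.

    Hypothetical helper, not a scipy function: the distribution is defined
    over the integer indices 0..n-1 and the drawn indices are mapped back
    to the float values.
    """
    values = np.asarray(values, dtype=float)
    probabilities = np.asarray(probabilities, dtype=float)
    if values.shape != probabilities.shape:
        raise ValueError("values and probabilities must have the same shape")

    # rv_discrete is documented for integer support, so sample indices.
    indices = np.arange(len(values))
    distrib = stats.rv_discrete(values=(indices, probabilities))

    # Cast defensively so the result can always be used for indexing.
    idx = np.asarray(distrib.rvs(size=size, random_state=seed)).astype(int)
    return values[idx]


samples = sample_float_values([0.1, 0.2, 0.3], [0.2, 0.5, 0.3],
                              size=10000, seed=2016)
print(samples[:10])
print(np.mean(samples == 0.2))  # should be close to 0.5
```

With a large sample, the fraction of draws equal to each value should sit close to its probability, which is a quick sanity check that the indices line up with values.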
unutbu answered Mar 25 '23 02:03