
Vectorizing `numpy.random.choice` for given 2D array of probabilities along an axis

NumPy has the `random.choice` function, which lets you sample from a categorical distribution. How would you repeat this along an axis, so that each row of a 2D array is treated as its own distribution? To illustrate what I mean, here is my current code:

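import numpy as np

# Each row is one categorical distribution (rows sum to 1)
categorical_distributions = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
_, n = categorical_distributions.shape
# One sampled index per row
np.array([np.random.choice(n, p=row)
          for row in categorical_distributions])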

Ideally, I would like to eliminate the for loop.

ethanabrooks asked Dec 08 '17 20:12



1 Answer

Here's one vectorized way to get a random index per row, with `a` as the 2D array of probabilities (each row summing to 1) -

(a.cumsum(1) > np.random.rand(a.shape[0])[:,None]).argmax(1)
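
To see why the one-liner above works, here is a minimal step-by-step sketch. The uniform draws are fixed by hand (an assumption made purely for illustration; in real use they come from `np.random.rand`) so that each intermediate array can be checked manually:

```python
import numpy as np

a = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])

# Hand-picked draws, one per row, standing in for np.random.rand(a.shape[0])
r = np.array([0.35, 0.95])

cdf = a.cumsum(axis=1)     # row-wise cumulative probabilities
mask = cdf > r[:, None]    # True from the first bin whose cumulative value exceeds the draw
idx = mask.argmax(axis=1)  # argmax returns the position of the first True in each row

print(cdf)
print(mask)
print(idx)
```

Each bin is selected whenever the draw lands in its slice of the cumulative range, and the width of that slice equals the bin's probability.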

Generalizing to sample along either the rows (axis=1) or the columns (axis=0) of a 2D array -

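def random_choice_prob_index(a, axis=1):
    # One uniform draw in [0, 1) per distribution, shaped to broadcast against a
    r = np.expand_dims(np.random.rand(a.shape[1-axis]), axis=axis)
    # Index of the first bin whose cumulative probability exceeds the draw
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)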
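
A short usage sketch for the `axis=0` case, where each column is a distribution. The one-hot array `cols` is a made-up example, chosen so that the result does not depend on the random draws:

```python
import numpy as np

def random_choice_prob_index(a, axis=1):
    r = np.expand_dims(np.random.rand(a.shape[1 - axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)

# Hypothetical input: every column puts all of its probability on one index
cols = np.array([
    [0., 1.],
    [1., 0.],
    [0., 0.],
])

picked = random_choice_prob_index(cols, axis=0)  # one index per column
print(picked)

# The same distributions stored as rows give the same answer with axis=1
picked_rows = random_choice_prob_index(cols.T, axis=1)
print(picked_rows)
```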

Let's verify with the given sample by running it a million times -

In [589]: a = np.array([
     ...:     [.1, .3, .6],
     ...:     [.2, .4, .4],
     ...: ])

In [590]: choices = [random_choice_prob_index(a)[0] for i in range(1000000)]

# This should be close to the first row of the given sample
In [591]: np.bincount(choices)/float(len(choices))
Out[591]: array([ 0.099781,  0.299436,  0.600783])
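
The check above only looks at the first row. As a sketch of a way to check every row in a single vectorized call, the array can be repeated and the draws grouped by source row. The sample size `n_draws` and the fixed seed are arbitrary choices made for this example:

```python
import numpy as np

def random_choice_prob_index(a, axis=1):
    r = np.expand_dims(np.random.rand(a.shape[1 - axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)

a = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
n_draws = 200_000   # arbitrary sample size for this sketch
np.random.seed(0)   # fixed seed only so the check is repeatable

# np.repeat places n_draws copies of row 0 first, then n_draws copies of row 1,
# so reshaping groups the sampled indices by the row they came from
samples = random_choice_prob_index(np.repeat(a, n_draws, axis=0))
samples = samples.reshape(len(a), n_draws)

# Empirical frequency of each index, one row per distribution
freqs = np.stack([np.bincount(s, minlength=a.shape[1]) for s in samples]) / n_draws
print(freqs)  # each row should be close to the matching row of a
```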

Runtime test

Original loopy way -

def loopy_app(categorical_distributions):
    m, n = categorical_distributions.shape
    out = np.empty(m, dtype=int)
    for i,row in enumerate(categorical_distributions):
        out[i] = np.random.choice(n, p=row)
    return out

Timings on a bigger array -

In [593]: a = np.array([
     ...:     [.1, .3, .6],
     ...:     [.2, .4, .4],
     ...: ])

In [594]: a_big = np.repeat(a,100000,axis=0)

In [595]: %timeit loopy_app(a_big)
1 loop, best of 3: 2.54 s per loop

In [596]: %timeit random_choice_prob_index(a_big)
100 loops, best of 3: 6.44 ms per loop
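
As a possible variant, not part of the original answer, the same idea can be written against NumPy's newer `Generator` API (`np.random.default_rng`, available from NumPy 1.17), which makes seeding explicit. The name `random_choice_prob_index_rng` and the normalization step are assumptions of this sketch. The normalization guards against a cumulative sum that ends slightly below 1 due to floating-point rounding, in which case a large draw would match no bin and `argmax` would silently return 0:

```python
import numpy as np

def random_choice_prob_index_rng(a, axis=1, rng=None):
    # Hypothetical helper: same technique, using a Generator instead of np.random.rand
    rng = np.random.default_rng() if rng is None else rng
    cdf = a.cumsum(axis=axis)
    # Divide by the final cumulative value so the last bin is exactly 1.0
    cdf = cdf / np.take(cdf, [-1], axis=axis)
    r = np.expand_dims(rng.random(a.shape[1 - axis]), axis=axis)
    return (cdf > r).argmax(axis=axis)

a = np.array([
    [.1, .3, .6],
    [.2, .4, .4],
])
idx = random_choice_prob_index_rng(a, rng=np.random.default_rng(42))
print(idx)
```

Passing the same seed reproduces the same indices, which is handy in tests.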
Divakar answered Oct 27 '22 00:10
