
Efficient way of sampling from indices of a Numpy array?

I'd like to sample indices of a 2D NumPy array, where each index is weighted by the value stored at that position. The only way I know is numpy.random.choice, but that returns the sampled value rather than its index. Is there an efficient way to do this?

Here is my code:

import numpy as np
A=np.arange(1,10).reshape(3,3)
A_flat=A.flatten()
d=np.random.choice(A_flat,size=10,p=A_flat/float(np.sum(A_flat)))
print(d)
Cupitor asked Oct 03 '22 12:10



2 Answers

You could do something like:

import numpy as np

def wc(weights):
    cs = np.cumsum(weights)
    idx = cs.searchsorted(np.random.random() * cs[-1], 'right')
    return np.unravel_index(idx, weights.shape)

Note that the cumsum is the slowest part of this, so if you need to do this repeatedly for the same array, I'd suggest computing the cumsum ahead of time and reusing it.
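A minimal sketch of that reuse, assuming a recent NumPy with np.random.default_rng. The names make_sampler and sample are made up for illustration and are not part of NumPy. The cumulative sum is computed once, and each call then draws a whole batch of indices with a vectorized searchsorted:

```python
import numpy as np

def make_sampler(weights, rng=None):
    """Return a function that draws index tuples, reusing one cumulative sum."""
    weights = np.asarray(weights, dtype=float)
    rng = np.random.default_rng() if rng is None else rng
    cs = np.cumsum(weights)  # computed once; cumsum flattens the array by default

    def sample(size):
        # Uniform draws in [0, total weight), mapped to flat indices
        u = rng.random(size) * cs[-1]
        flat = np.searchsorted(cs, u, side='right')
        # Guard against floating-point rounding at the upper edge
        flat = np.minimum(flat, cs.size - 1)
        # Convert flat indices back to one index array per dimension
        return np.unravel_index(flat, weights.shape)

    return sample

A = np.arange(1, 10).reshape(3, 3)
sample = make_sampler(A)
rows, cols = sample(10)  # ten (row, col) pairs, weighted by the values in A
```

Each later call to sample only pays for the random draws and a binary search, not for another cumsum.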

Bi Rico answered Oct 05 '22 04:10



To expand on my comment, you can adapt the weighted choice method presented here: https://stackoverflow.com/a/10803136/553404

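import numpy as np

def weighted_choice_indices(weights):
    # Normalized cumulative distribution over the flattened weights
    cs = np.cumsum(weights.flatten())/np.sum(weights)
    # Counting the entries below a uniform draw gives the flat index
    idx = np.sum(cs < np.random.rand())
    return np.unravel_index(idx, weights.shape)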
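For comparison, the same result can probably be had with the random choice routine from the question by sampling flat positions instead of the stored values and then unravelling them. This is only a sketch: the function name choice_indices is made up for illustration, and it assumes a NumPy version that provides np.random.default_rng.

```python
import numpy as np

def choice_indices(weights, size, rng=None):
    """Draw `size` index tuples with probability proportional to `weights`."""
    weights = np.asarray(weights, dtype=float)
    rng = np.random.default_rng() if rng is None else rng
    # Probabilities over the flattened array
    p = weights.ravel() / weights.sum()
    # Sample flat positions 0..weights.size-1 rather than the stored values
    flat = rng.choice(weights.size, size=size, p=p)
    # Convert flat positions back to one index array per dimension
    return np.unravel_index(flat, weights.shape)

A = np.arange(1, 10).reshape(3, 3)
rows, cols = choice_indices(A, size=10)
print(list(zip(rows, cols)))
```

This draws all samples in one call, which may be more convenient than calling a single-sample function in a loop.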
YXD answered Oct 05 '22 02:10
