 

Random sample from a very long iterable, in python

I have a long python generator that I want to "thin out" by randomly selecting a subset of values. Unfortunately, random.sample() will not work with arbitrary iterables. Apparently, it needs something that supports the len() operation (and perhaps non-sequential access to the sequence, but that's not clear). And I don't want to build an enormous list just so I can thin it out.

As a matter of fact, it is possible to sample from a sequence uniformly in one pass, without knowing its length: there's a nice algorithm in Programming Perl that does just that (edit: "reservoir sampling", thanks @user2357112!). But does anyone know of a standard Python module that provides this functionality?

Demo of the problem (Python 3)

>>> import itertools, random
>>> random.sample(iter("abcd"), 2)
...
TypeError: Population must be a sequence or set.  For dicts, use list(d).

On Python 2, the error is more transparent:

Traceback (most recent call last):
  File "<pyshell#12>", line 1, in <module>
    random.sample(iter("abcd"), 2)
  File "/usr/local/Cellar/python/2.7.9/Frameworks/Python.framework/Versions/2.7/lib/python2.7/random.py", line 321, in sample
    n = len(population)
TypeError: object of type 'iterator' has no len()

If there's no alternative to random.sample(), I'd try my luck with wrapping the generator into an object that provides a __len__ method (I can find out the length in advance). So I'll accept an answer that shows how to do that cleanly.

asked Dec 11 '22 18:12 by alexis


1 Answer

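Since you know the length of the data returned by your iterable, you can use range() (xrange() on Python 2) to quickly generate indices into your iterable. Then you can just run the iterable until you've grabbed all of the data:

import random

def sample(it, length, k):
    # Choose k distinct positions in [0, length) and remember which result slot each fills.
    slots = {index: slot for slot, index in enumerate(random.sample(range(length), k))}
    result = [None] * k
    remaining = k
    for index, datum in enumerate(it):
        if index in slots:
            result[slots[index]] = datum
            remaining -= 1
            if remaining == 0:
                break  # every chosen position has been filled
    return result

print(sample(iter("abcd"), 4, 2))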
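If the items should come back in the order they appear in the stream, a variant of the same idea is to sort the chosen positions and skip ahead to each one with itertools.islice(). This is a minimal sketch, assuming (as in the question) that `length` is the exact number of items the iterator yields; the name `sample_in_order` is hypothetical. It stops reading the iterator right after the last chosen position.

```python
import itertools
import random

def sample_in_order(it, length, k):
    """Return k items from `it` at distinct random positions, in stream order.

    Assumes `length` is exactly the number of items `it` yields; if it
    overstates the length, next() may raise StopIteration.
    """
    it = iter(it)
    positions = sorted(random.sample(range(length), k))
    result = []
    previous = -1
    for pos in positions:
        gap = pos - previous - 1  # items to discard before the next chosen one
        # islice(it, gap, gap + 1) skips `gap` items and yields the one after them.
        result.append(next(itertools.islice(it, gap, gap + 1)))
        previous = pos
    return result

print(sample_in_order(iter("abcd"), 4, 2))
```

Sorting the k positions costs O(k log k); the stream itself is read once, and only up to the last chosen item.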

Alternatively, here is an implementation of reservoir sampling using "Algorithm R":

import random

def R(it, k):
    '''https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R'''
    result = []
    for i, datum in enumerate(it):
        if i < k:
            result.append(datum)  # fill the reservoir with the first k items
        else:
            # Item i (0-based) is the (i+1)-th seen; keep it with probability k/(i+1).
            j = random.randint(0, i)
            if j < k:
                result[j] = datum
    return result

print(R(iter("abcd"), 2))

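Note that Algorithm R doesn't return the results in random order: each of the first k items either stays in its original slot or is replaced. In the example given, 'b' will never precede 'a' in the results. If order matters, call random.shuffle() on the result.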
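A different one-pass technique, swapped in here rather than taken from the answer above, needs no length and returns results in random order: give every item an independent random key and keep the k items with the largest keys. heapq.nlargest() does this in a single pass, and because the keys are independent and identically distributed, the chosen k-subset is uniformly random and the returned order (by key) is random too. A minimal sketch; the function name `sample_by_random_keys` is hypothetical.

```python
import heapq
import random

def sample_by_random_keys(iterable, k):
    """Return a uniform random sample of up to k items from `iterable`.

    Each item gets an independent random key; the k items with the largest
    keys are kept. For a plain iterator this reads the data once and holds
    at most k items at a time. Results come back ordered by key, which is
    itself a random order.
    """
    return heapq.nlargest(k, iterable, key=lambda _item: random.random())

print(sample_by_random_keys(iter("abcd"), 2))
```

This costs O(n log k) time versus O(n) for Algorithm R, in exchange for the random order and a standard-library one-liner.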

answered Jan 06 '23 01:01 by Robᵩ