
Better way of taking random sample without replacement from list in Haskell

Tags:

list

haskell

I need to take a random sample without replacement (each element only occurring once in the sample) from a longer list. I'm using the code below, but now I'd like to know:

  1. Is there a library function that does this?
  2. How can I improve this code? (I'm a Haskell beginner, so this would be useful even if there is a library function).

The purpose of the sampling is to be able to generalize findings from analyzing the sample to the population.

import System.Random

-- | Take a random sample without replacement of size size from a list.
takeRandomSample :: Int -> Int -> [a] -> [a]
takeRandomSample seed size xs
    | size < hi  = subset xs rs
    | otherwise = error "Sample size must be smaller than population."
    where
        rs = randomSample seed size lo hi
        lo = 0
        hi = length xs - 1

getOneRandomV g lo hi = randomR (lo, hi) g

rsHelper size lo hi g x acc
    | x `notElem` acc && length acc < size = rsHelper size lo hi new_g new_x (x:acc)
    | x `elem` acc && length acc < size = rsHelper size lo hi new_g new_x acc
    | otherwise = acc
    where (new_x, new_g) = getOneRandomV g lo hi

-- | Get a random sample without replacement of size size between lo and hi.
randomSample seed size lo hi = rsHelper size lo hi g x [] where
    (x, g)  = getOneRandomV (mkStdGen seed) lo hi

subset l = map (l !!) 

ajerneck asked Dec 08 '12 16:12


2 Answers

Here's a quick 'back-of-the-envelope' implementation of what Daniel Fischer suggested in his comment, using my preferred PRNG (mwc-random):

{-# LANGUAGE BangPatterns #-}

module Sample (sample) where

import Control.Monad.Primitive
import Data.Foldable (toList)
import qualified Data.Sequence as Seq
import System.Random.MWC

sample :: PrimMonad m => [a] -> Int -> Gen (PrimState m) -> m [a]
sample ys size = go 0 (l - 1) (Seq.fromList ys) where
    l = length ys
    go !n !i xs g | n >= size = return $! (toList . Seq.drop (l - size)) xs
                  | otherwise = do
                      j <- uniformR (0, i) g
                      let toI  = xs `Seq.index` j
                          toJ  = xs `Seq.index` i
                          next = (Seq.update i toI . Seq.update j toJ) xs
                      go (n + 1) (i - 1) next g
{-# INLINE sample #-}

This is pretty much a (terse) functional rewrite of R's internal C version of sample() when it's called without replacement.

sample is just a wrapper over a recursive worker function that incrementally shuffles the population until the desired sample size is reached, returning only that many shuffled elements. Because the wrapper itself is not recursive (GHC does not inline self-recursive functions), writing it this way lets GHC inline it.

It's easy to use:

*Main> create >>= sample [1..100] 10
[51,94,58,3,91,70,19,65,24,53]

A production version might want to use something like a mutable vector instead of Data.Sequence in order to cut down on time spent doing GC.
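
As a rough sketch of that mutable-vector idea (not a tuned implementation), the same partial shuffle could be written against Data.Vector.Mutable. The names sampleV and demo are hypothetical, the vector and mwc-random packages are assumed to be installed, and clamping the sample size to the population size is my own assumption rather than something the version above does:

```haskell
import Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import System.Random.MWC (Gen, create, uniformR)

-- | Hypothetical mutable-vector variant of the partial shuffle:
-- swap a randomly chosen element into the tail of the vector,
-- once per sampled element, then return that tail.
sampleV :: PrimMonad m => [a] -> Int -> Gen (PrimState m) -> m [a]
sampleV ys size g = do
  mv <- V.thaw (V.fromList ys)
  let l = MV.length mv
      -- Assumption: clamp the requested size into [0, l].
      k = max 0 (min size l)
      go n i
        | n >= k = return ()
        | otherwise = do
            j <- uniformR (0, i) g   -- inclusive range [0, i]
            MV.swap mv i j           -- in place, no new structure allocated
            go (n + 1) (i - 1)
  go 0 (l - 1)
  -- Safe to freeze without copying: mv is not used after this point.
  v <- V.unsafeFreeze mv
  return (V.toList (V.drop (l - k) v))

-- | Example call using mwc-random's fixed-seed generator.
demo :: IO [Int]
demo = create >>= sampleV [1 .. 100] 10
```

Each step is a single in-place swap, so the only allocation is the initial copy of the list into the vector.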

jtobin answered Sep 18 '22 00:09


I think a standard way to do this (usually called reservoir sampling) is to keep a fixed-size buffer initialized with the first N elements, and for each i'th element, i >= N, do this:

  1. Pick a random number, j, between 0 and i, inclusive.
  2. If j < N then replace the j'th element in the buffer with the current one.

You can prove correctness by induction:

This clearly generates a random sample (I assume order is irrelevant) if you only have N elements. Now suppose it's true up to the i'th element. This means that the probability of any element being in the buffer is N/(i+1) (I start counting at 0).

After picking the random number, the probability that the i+1'th element is in the buffer is N/(i+2) (j is between 0 and i+1, and N of those end up in the buffer). What about the others?

P(k'th element is in the buffer after processing the i+1'th) =
P(k'th element was in the buffer before)*P(k'th element is not replaced) =
N/(i+1) * (1-1/(i+2)) =
N/(i+2)
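
If you want to double-check the algebra in that last step, exact rational arithmetic makes it easy. This is only a sanity check of the identity, not a proof, and the names survives and inductionHolds are made up for the illustration:

```haskell
import Data.Ratio ((%))

-- | Probability that an earlier element is still in a buffer of size n
-- after the element at index i+1 has been processed:
-- it was in the buffer (n/(i+1)) and its slot was not picked (1 - 1/(i+2)).
survives :: Integer -> Integer -> Rational
survives n i = (n % (i + 1)) * (1 - 1 % (i + 2))

-- | True when the step equals n/(i+2) for every i from n-1 up to upTo.
-- Assumes n >= 1 so that no denominator is zero.
inductionHolds :: Integer -> Integer -> Bool
inductionHolds n upTo =
  and [survives n i == n % (i + 2) | i <- [n - 1 .. upTo]]
```

The range starts at i = n-1 because that is the base case, where the buffer holds exactly the first n elements.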

Here's some code that does it, using space proportional to the sample size and the standard (slow) System.Random.

import Control.Monad (when)
import Data.Array
import Data.Array.ST
import System.Random (RandomGen, randomR)

sample :: RandomGen g => g -> Int -> [Int] -> [Int]
sample g size xs =
  -- Reject requests for more elements than the input contains.
  if size > length xs
  then error "sample size must be <= input length"
  else elems $ runSTArray $ do
    arr <- newListArray (0, size-1) pre
    loop arr g size post
  where
    (pre, post) = splitAt size xs
    loop arr g i [] = return arr
    loop arr g i (x:xt) = do
      let (j, g') = randomR (0, i) g
      when (j < size) $ writeArray arr j x
      loop arr g' (i+1) xt
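
For comparison, here is a sketch of the same buffer-replacement scheme as a pure fold that works for any element type, using a Data.Map as the buffer instead of an ST array. The names reservoir and example are hypothetical, and rejecting a negative size is my own assumption about the desired behaviour:

```haskell
import Data.List (foldl')
import qualified Data.Map.Strict as M
import System.Random (RandomGen, mkStdGen, randomR)

-- | Hypothetical pure variant: buffer slots 0..size-1 live in a Map.
reservoir :: RandomGen g => g -> Int -> [a] -> [a]
reservoir g0 size xs
  -- Only the first `size` elements are counted, so the input
  -- is not traversed an extra time just to validate the size.
  | size < 0 || length pre < size =
      error "sample size must be between 0 and the input length"
  | otherwise = M.elems final
  where
    (pre, post) = splitAt size xs
    buf0 = M.fromList (zip [0 ..] pre)
    (final, _, _) = foldl' step (buf0, g0, size) post
    -- i is the index of the element x currently being processed.
    step (buf, g, i) x =
      let (j, g') = randomR (0, i) g            -- inclusive range [0, i]
          buf' = if j < size then M.insert j x buf else buf
      in (buf', g', i + 1)

-- | Example call with a fixed seed.
example :: [Int]
example = reservoir (mkStdGen 42) 5 [1 .. 100]
```

Map updates cost O(log size) each, so this trades some speed for not needing ST; when size equals the input length, the input comes back unchanged.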

Itai Zukerman answered Sep 19 '22 00:09