 

Scipy Multivariate Normal: How to draw deterministic samples?

I am using scipy.stats.multivariate_normal to draw samples from a multivariate normal distribution, like this:

from scipy.stats import multivariate_normal
# Assume we have means and covs
mn = multivariate_normal(mean = means, cov = covs)
# Generate some samples
samples = mn.rvs()

The samples are different on every run. How do I always get the same samples? I was expecting something like:

mn = multivariate_normal(mean = means, cov = covs, seed = aNumber)

or

samples = mn.rvs(seed = aNumber)
k88074 asked Feb 05 '23 19:02



1 Answer

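There are two ways. (The IPython sessions below assume numpy has been imported as np and multivariate_normal has been imported from scipy.stats.)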

  1. The rvs() method accepts a random_state argument. Its value can be an integer seed, or an instance of numpy.random.Generator or numpy.random.RandomState. In this example, I use an integer seed:

     In [46]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
    
     In [47]: mn.rvs(size=5, random_state=12345)
     Out[47]: 
     array([[-0.51943872,  1.07094986, -1.0235383 ],
            [ 1.39340583,  4.39561899, -2.77865152],
            [ 0.76902257,  0.63000355,  0.46453938],
            [-1.29622111,  2.25214387,  6.23217368],
            [ 1.35291684,  0.51186476,  1.37495817]])
    
     In [48]: mn.rvs(size=5, random_state=12345)
     Out[48]: 
     array([[-0.51943872,  1.07094986, -1.0235383 ],
            [ 1.39340583,  4.39561899, -2.77865152],
            [ 0.76902257,  0.63000355,  0.46453938],
            [-1.29622111,  2.25214387,  6.23217368],
            [ 1.35291684,  0.51186476,  1.37495817]])
    

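    This version uses an instance of numpy.random.Generator. A Generator's state advances with every draw, so calling mn.rvs() again with the same rng returns new samples; to reproduce these, create a fresh generator from the same seed: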

    In [34]: rng = np.random.default_rng(438753948759384)
    
    In [35]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
    
    In [36]: mn.rvs(size=5, random_state=rng)
    Out[36]: 
    array([[ 0.30626179,  0.60742839,  2.86919105],
           [ 1.61859885,  2.63409111,  1.19018398],
           [ 0.35469027,  0.85685011,  6.76892829],
           [-0.88659459, -0.59922575, -5.43926698],
           [ 0.94777687, -5.80057427, -2.16887719]])
    
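  2. You can set the seed of NumPy's global random number generator (the legacy RandomState instance used by the np.random.* functions). This is the generator that multivariate_normal.rvs() uses if random_state is not given. Keep in mind that reseeding it also affects any other code that draws from the global generator: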

     In [54]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
    
     In [55]: np.random.seed(123)
    
     In [56]: mn.rvs(size=5)
     Out[56]: 
     array([[  0.2829785 ,   2.23013222,  -5.42815302],
            [  1.65143654,  -1.2937895 ,  -7.53147357],
            [  1.26593626,  -0.95907779, -12.13339622],
            [ -0.09470897,  -1.51803558,  -4.33370201],
            [ -0.44398196,  -1.4286283 ,   7.45694813]])
    
     In [57]: np.random.seed(123)
    
     In [58]: mn.rvs(size=5)
     Out[58]: 
     array([[  0.2829785 ,   2.23013222,  -5.42815302],
            [  1.65143654,  -1.2937895 ,  -7.53147357],
            [  1.26593626,  -0.95907779, -12.13339622],
            [ -0.09470897,  -1.51803558,  -4.33370201],
            [ -0.44398196,  -1.4286283 ,   7.45694813]])
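As a rough, self-contained check of both approaches, the sketch below compares draws with np.array_equal instead of eyeballing printed arrays. It assumes a reasonably recent NumPy (default_rng exists from NumPy 1.17) and a SciPy release whose rvs accepts a numpy.random.Generator as random_state; the seeds are simply the ones used in the sessions above.

```python
import numpy as np
from scipy.stats import multivariate_normal

# Diagonal covariance given as a 1-D array, as in the sessions above
mn = multivariate_normal(mean=[0, 0, 0], cov=[1, 5, 25])

# Way 1a: an integer random_state seeds a fresh RandomState on every call,
# so repeated calls return identical samples.
a = mn.rvs(size=5, random_state=12345)
b = mn.rvs(size=5, random_state=12345)
same_int_seed = np.array_equal(a, b)

# Way 1b: a Generator object carries state; reusing the same object
# advances that state, so the second draw differs from the first.
rng = np.random.default_rng(438753948759384)
c = mn.rvs(size=5, random_state=rng)
d = mn.rvs(size=5, random_state=rng)
reused_generator_differs = not np.array_equal(c, d)

# Re-creating the Generator from the same seed reproduces the first draw.
e = mn.rvs(size=5, random_state=np.random.default_rng(438753948759384))
recreated_generator_same = np.array_equal(c, e)

# Way 2: reseed NumPy's global (legacy) RandomState before each call.
np.random.seed(123)
f = mn.rvs(size=5)
np.random.seed(123)
g = mn.rvs(size=5)
same_global_seed = np.array_equal(f, g)

print(same_int_seed, reused_generator_differs, recreated_generator_same, same_global_seed)
# prints: True True True True
```

Passing random_state explicitly is usually the safer choice: it keeps the seeding local to this call, whereas reseeding the global generator changes the numbers seen by every other part of the program that uses np.random.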
    
Warren Weckesser answered Feb 07 '23 17:02
