I am using scipy.stats.multivariate_normal to draw samples from a multivariate normal distribution, like this:
from scipy.stats import multivariate_normal
# Assume we have means and covs
mn = multivariate_normal(mean = means, cov = covs)
# Generate some samples
samples = mn.rvs()
The samples are different on every run. How do I always get the same sample? I was expecting something like:
mn = multivariate_normal(mean = means, cov = covs, seed = aNumber)
or
samples = mn.rvs(seed = aNumber)
There are two ways:
The rvs() method accepts a random_state argument. Its value can be an integer seed, or an instance of numpy.random.Generator or numpy.random.RandomState. In this example, I use an integer seed:
In [46]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
In [47]: mn.rvs(size=5, random_state=12345)
Out[47]:
array([[-0.51943872, 1.07094986, -1.0235383 ],
[ 1.39340583, 4.39561899, -2.77865152],
[ 0.76902257, 0.63000355, 0.46453938],
[-1.29622111, 2.25214387, 6.23217368],
[ 1.35291684, 0.51186476, 1.37495817]])
In [48]: mn.rvs(size=5, random_state=12345)
Out[48]:
array([[-0.51943872, 1.07094986, -1.0235383 ],
[ 1.39340583, 4.39561899, -2.77865152],
[ 0.76902257, 0.63000355, 0.46453938],
[-1.29622111, 2.25214387, 6.23217368],
[ 1.35291684, 0.51186476, 1.37495817]])
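The same check can be written as a standalone script. This is a minimal sketch: it compares two calls made with the same integer seed, and also passes a freshly created numpy.random.RandomState seeded with that integer. SciPy turns an integer random_state into a new RandomState seeded with it, so all three draws should match. The variable names are only for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal

# A 1-D cov is interpreted as the diagonal of the covariance matrix.
mn = multivariate_normal(mean=[0, 0, 0], cov=[1, 5, 25])

# Integer seed: every call starts from a generator seeded with 12345,
# so repeated calls return identical samples.
a = mn.rvs(size=5, random_state=12345)
b = mn.rvs(size=5, random_state=12345)
print(np.array_equal(a, b))  # True

# A legacy RandomState created with the same seed gives the same draws.
c = mn.rvs(size=5, random_state=np.random.RandomState(12345))
print(np.array_equal(a, c))  # True
```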
This version uses an instance of numpy.random.Generator:
In [34]: rng = np.random.default_rng(438753948759384)
In [35]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
In [36]: mn.rvs(size=5, random_state=rng)
Out[36]:
array([[ 0.30626179, 0.60742839, 2.86919105],
[ 1.61859885, 2.63409111, 1.19018398],
[ 0.35469027, 0.85685011, 6.76892829],
[-0.88659459, -0.59922575, -5.43926698],
[ 0.94777687, -5.80057427, -2.16887719]])
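One thing to keep in mind with a Generator instance: it carries state, so passing the same rng to rvs() twice yields two different batches. To reproduce a batch, create a new Generator with the same seed. A minimal sketch (variable names are illustrative):

```python
import numpy as np
from scipy.stats import multivariate_normal

mn = multivariate_normal(mean=[0, 0, 0], cov=[1, 5, 25])

# Each call advances the Generator, so reusing it gives new samples.
rng = np.random.default_rng(438753948759384)
first = mn.rvs(size=5, random_state=rng)
second = mn.rvs(size=5, random_state=rng)
print(np.array_equal(first, second))  # False

# Re-creating the Generator with the same seed replays the first batch.
rng_again = np.random.default_rng(438753948759384)
replay = mn.rvs(size=5, random_state=rng_again)
print(np.array_equal(first, replay))  # True
```

Creating the Generator once at the start of a program therefore gives the same sequence of samples on every run.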
The other way is to set the seed for numpy's global random number generator. This is the generator that multivariate_normal.rvs() uses if random_state is not given:
In [54]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
In [55]: np.random.seed(123)
In [56]: mn.rvs(size=5)
Out[56]:
array([[ 0.2829785 , 2.23013222, -5.42815302],
[ 1.65143654, -1.2937895 , -7.53147357],
[ 1.26593626, -0.95907779, -12.13339622],
[ -0.09470897, -1.51803558, -4.33370201],
[ -0.44398196, -1.4286283 , 7.45694813]])
In [57]: np.random.seed(123)
In [58]: mn.rvs(size=5)
Out[58]:
array([[ 0.2829785 , 2.23013222, -5.42815302],
[ 1.65143654, -1.2937895 , -7.53147357],
[ 1.26593626, -0.95907779, -12.13339622],
[ -0.09470897, -1.51803558, -4.33370201],
[ -0.44398196, -1.4286283 , 7.45694813]])
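Because this relies on shared global state, any other code that draws from np.random between the seed call and rvs() changes the result. A minimal sketch showing both the reproducible case and what happens when one extra draw sneaks in:

```python
import numpy as np
from scipy.stats import multivariate_normal

mn = multivariate_normal(mean=[0, 0, 0], cov=[1, 5, 25])

# Without random_state, rvs() draws from numpy's global RandomState.
np.random.seed(123)
x = mn.rvs(size=5)
np.random.seed(123)
y = mn.rvs(size=5)
print(np.array_equal(x, y))  # True

# One unrelated draw from the global generator shifts the stream.
np.random.seed(123)
np.random.random()
z = mn.rvs(size=5)
print(np.array_equal(x, z))  # False
```

Passing an explicit random_state avoids this coupling, which is why the first approach is usually easier to reason about.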