 

How to get a normal distribution within a range in numpy? [duplicate]

In a machine learning task, I need to draw a group of random numbers from a normal distribution restricted to a given range (lower and upper bound). np.random.normal() produces normally distributed numbers, but it doesn't offer any bound parameter. How can I do that?
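For comparison with the SciPy answer below, a bounded draw can also be sketched with NumPy alone. This is a minimal sketch, assuming plain rejection sampling is acceptable: every value that lands outside [low, upp] is redrawn until all values are in range. `bounded_normal` is a hypothetical helper, not a NumPy function. Unlike np.clip, which piles all the out-of-range mass onto the two bounds, redrawing keeps the bell shape inside the range intact.

```python
import numpy as np


def bounded_normal(mean, sd, low, upp, size, rng=None):
    """Hypothetical helper: normal draws restricted to [low, upp] by rejection."""
    if rng is None:
        rng = np.random.default_rng()
    samples = rng.normal(mean, sd, size)
    # Mask of draws that fall outside the requested range
    bad = (samples < low) | (samples > upp)
    while bad.any():
        # Redraw only the out-of-range values, then re-check them
        samples[bad] = rng.normal(mean, sd, bad.sum())
        bad = (samples < low) | (samples > upp)
    return samples


x = bounded_normal(mean=8, sd=2, low=1, upp=10, size=1000,
                   rng=np.random.default_rng(0))
```

The loop can take a long time when [low, upp] covers only a tiny fraction of the distribution's probability mass, so this approach suits ranges that keep most of the bell.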

asked Apr 27 '16 15:04 by maple




1 Answer

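The parametrization of scipy.stats.truncnorm is unintuitive: its bounds a and b are expressed in standard deviations relative to loc, not on the original scale. Here is a function that translates the more intuitive parameters (mean, standard deviation, lower and upper bound) into truncnorm's: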

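    from scipy.stats import truncnorm

    def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
        # truncnorm takes its clip points in sd units relative to loc,
        # so convert the bounds before passing them in
        return truncnorm(
            (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)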
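To make the translation concrete, here is a short sketch that calls truncnorm directly with the standardized bounds for mean=8, sd=2 and the range [1, 10], then uses the frozen distribution's support() to confirm they map back to the original scale. The numbers are illustrative only.

```python
from scipy.stats import truncnorm

mean, sd, low, upp = 8, 2, 1, 10
# truncnorm wants the clip points in units of sd, measured from loc
a, b = (low - mean) / sd, (upp - mean) / sd   # a = -3.5, b = 1.0
X = truncnorm(a, b, loc=mean, scale=sd)
# support() maps the standardized bounds back through loc and scale
lo, hi = X.support()                          # 1.0 and 10.0
```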


How to use it?

  1. Instantiate the generator with the parameters: mean, standard deviation, and truncation range:

    >>> X = get_truncated_normal(mean=8, sd=2, low=1, upp=10) 
  2. Then, you can use X to generate a value:

    >>> X.rvs()
    6.0491227353928894
  3. Or generate a NumPy array of N values:

    >>> X.rvs(10)
    array([ 7.70231607,  6.7005871 ,  7.15203887,  6.06768994,  7.25153472,
            5.41384242,  7.75200702,  5.5725888 ,  7.38512757,  7.47567455])
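Two properties are worth checking in a short sketch: passing random_state (an integer seed or a NumPy Generator) to rvs makes the draws repeatable, and every sample falls inside [low, upp]. The sketch re-defines get_truncated_normal so it runs on its own; the seed values are arbitrary.

```python
import numpy as np
from scipy.stats import truncnorm


def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    return truncnorm(
        (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)


X = get_truncated_normal(mean=8, sd=2, low=1, upp=10)
# The same seed gives the same draws
s1 = X.rvs(size=10, random_state=42)
s2 = X.rvs(size=10, random_state=42)
# A large sample never leaves the truncation range
samples = X.rvs(size=100_000, random_state=np.random.default_rng(0))
```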

A Visual Example

Here is the plot of three different truncated normal distributions:

    import matplotlib.pyplot as plt

    X1 = get_truncated_normal(mean=2, sd=1, low=1, upp=10)
    X2 = get_truncated_normal(mean=5.5, sd=1, low=1, upp=10)
    X3 = get_truncated_normal(mean=8, sd=1, low=1, upp=10)

    fig, ax = plt.subplots(3, sharex=True)
    # density=True replaces normed=True, which Matplotlib deprecated and later removed
    ax[0].hist(X1.rvs(10000), density=True)
    ax[1].hist(X2.rvs(10000), density=True)
    ax[2].hist(X3.rvs(10000), density=True)
    plt.show()

[Figure: normalized histograms of 10,000 samples from X1, X2 and X3]
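One consequence of truncation: the mean and sd arguments describe the parent normal, not the truncated result. Cutting off a tail shifts the mean toward the remaining side and shrinks the spread, and a frozen distribution's mean() and std() report the truncated values directly. A short sketch for the three distributions plotted above:

```python
from scipy.stats import truncnorm


def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    return truncnorm(
        (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)


X1 = get_truncated_normal(mean=2, sd=1, low=1, upp=10)
X2 = get_truncated_normal(mean=5.5, sd=1, low=1, upp=10)
X3 = get_truncated_normal(mean=8, sd=1, low=1, upp=10)
for name, X in [("X1", X1), ("X2", X2), ("X3", X3)]:
    # mean() and std() describe the truncated distribution, not the parent normal
    print(f"{name}: mean={float(X.mean()):.3f}, sd={float(X.std()):.3f}")
```

X1 loses its left tail below 1, so its mean lies above 2; X3 loses its right tail above 10, so its mean lies below 8; X2 is cut symmetrically (4.5 sd on each side), so its mean stays at 5.5.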

answered Sep 29 '22 11:09 by toto_tico