
How to make a truncated normal distribution in PyTorch?

I want to create a truncated normal distribution (that is, a Gaussian distribution restricted to a range) in PyTorch.
I want to be able to change the mean, std, and range.
Is there a PyTorch method for that?

Lupos asked Nov 06 '22 09:11



1 Answer

Use torch.nn.init.trunc_normal_.

Description as given in its docstring:

Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std^2) with values outside [a, b] redrawn until they are within the bounds. The method used for generating the random values works best when a <= mean <= b.
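
A minimal sketch of how this could be used to sample with a custom mean, std, and range. The helper name sample_truncated_normal and its low/high parameters are made up for illustration; only torch.nn.init.trunc_normal_ and its mean, std, a, b arguments come from PyTorch. Note that a and b are absolute cutoff values, not multiples of std.

```python
import torch


def sample_truncated_normal(shape, mean=0.0, std=1.0, low=-2.0, high=2.0):
    """Hypothetical helper: returns a tensor of the given shape filled with
    samples from N(mean, std^2) restricted to [low, high]."""
    out = torch.empty(shape)
    # trunc_normal_ fills the tensor in place; a and b are absolute bounds.
    torch.nn.init.trunc_normal_(out, mean=mean, std=std, a=low, b=high)
    return out


torch.manual_seed(0)  # only for reproducibility of this example
samples = sample_truncated_normal((10000,), mean=1.0, std=0.5, low=0.0, high=2.0)
print(samples.min().item(), samples.max().item(), samples.mean().item())
```

Because the function works in place, the same call can also be applied to an existing tensor, such as a layer's weight, instead of a freshly allocated one.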

Sachin Yadav answered Dec 02 '22 21:12
