When I sample from a distribution in PyTorch, both sample() and rsample() appear to give similar results:
import torch, seaborn as sns
x = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
sns.distplot(x.sample((100000,)))
sns.distplot(x.rsample((100000,)))
When should I use sample(), and when should I use rsample()?
Using rsample() allows for pathwise derivatives:
The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable.
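To make the difference concrete, here is a minimal sketch that builds the same Normal distribution with parameters that require gradients. The names mu, sigma, and loss are illustrative assumptions, not part of the original question. It shows that a tensor drawn with sample() is detached from the graph, while a tensor drawn with rsample() lets gradients flow back to the parameters:

```python
import torch
from torch.distributions import Normal

# Illustrative parameters (assumed names); requires_grad=True so we can
# check whether gradients reach them.
mu = torch.tensor([0.0], requires_grad=True)
sigma = torch.tensor([1.0], requires_grad=True)
dist = Normal(mu, sigma)

# sample() draws without tracking gradients: the result is detached.
s = dist.sample((1000,))
print(s.requires_grad)

# rsample() draws eps ~ N(0, 1) and returns mu + sigma * eps,
# so the result stays connected to mu and sigma.
r = dist.rsample((1000,))
print(r.requires_grad)

# An arbitrary differentiable loss on the reparameterized samples.
loss = (r ** 2).mean()
loss.backward()
print(mu.grad, sigma.grad)

# The same construction written by hand, to show what "a deterministic
# function of a parameter-free random variable" means here.
eps = torch.randn(1000, 1)
manual = mu + sigma * eps
print(manual.requires_grad)
```

In short, both methods draw from the same distribution, which is why the two plots look alike. Use sample() when you only need the values, and rsample() when the sample feeds into a loss that you want to differentiate with respect to the distribution's parameters, as in a VAE. Not every distribution supports rsample(); the has_rsample attribute on a distribution tells you whether it does.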