
What is the difference between sample() and rsample()?

When I sample from a distribution in PyTorch, both sample and rsample appear to give similar results:

import torch
import seaborn as sns

x = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

# Plot 100,000 draws from each method; the two histograms look alike.
sns.distplot(x.sample((100000,)))
sns.distplot(x.rsample((100000,)))

When should I use sample(), and when should I use rsample()?

apostofes asked Mar 04 '20 19:03


1 Answer

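Using rsample() allows for pathwise derivatives: gradients can flow from the sample back to the parameters of the distribution, which is not the case for sample(). The PyTorch documentation describes it as follows: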

The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable.
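The difference can be checked directly. The following is a minimal sketch, not part of the original answer, assuming the distribution's parameters are tensors created with requires_grad=True; the names mu and sigma are illustrative. For a Normal distribution, rsample() builds the draw as mu + sigma * eps with eps drawn from a standard normal, so the draw stays attached to the autograd graph, whereas sample() returns a detached tensor.

```python
import torch

# Learnable parameters of the distribution (illustrative names).
mu = torch.tensor([0.0], requires_grad=True)
sigma = torch.tensor([1.0], requires_grad=True)
dist = torch.distributions.Normal(mu, sigma)

# sample(): the draw is detached, so no gradient can reach mu or sigma.
s = dist.sample()
print(s.requires_grad)  # False

# rsample(): the draw is mu + sigma * eps with eps ~ N(0, 1),
# so it is a differentiable function of mu and sigma.
r = dist.rsample()
print(r.requires_grad)  # True

# Backpropagate through the reparameterized draw.
r.sum().backward()
print(mu.grad)     # d r / d mu = 1
print(sigma.grad)  # d r / d sigma = eps, the noise that was drawn
```

As a rule of thumb, sample() is enough when the draw is only used as data (for example plotting, or score-function estimators that differentiate log_prob), while rsample() is needed when a loss computed from the draw itself must be backpropagated into the distribution's parameters, as in a variational autoencoder. Not every distribution supports it; the has_rsample attribute indicates whether rsample() is available.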

Shai answered Nov 20 '22 18:11