Understanding log_prob for Normal distribution in pytorch

I'm currently trying to solve Pendulum-v0 from the OpenAI Gym environment, which has a continuous action space. As a result, I need to use a Normal distribution to sample my actions. What I don't understand is the dimension of the log_prob when using it:

import torch
from torch.distributions import Normal 

means = torch.tensor([[0.0538],
        [0.0651]])
stds = torch.tensor([[0.7865],
        [0.7792]])

dist = Normal(means, stds)
a = torch.tensor([1.2,3.4])
d = dist.log_prob(a)
print(d.size())

I was expecting a tensor of size 2 (one log_prob for each action), but it outputs a tensor of size (2, 2).

However, when using a Categorical distribution for discrete environment the log_prob has the expected size:

from torch.distributions import Categorical

logits = torch.tensor([[-0.0657, -0.0949],
        [-0.0586, -0.1007]])

dist = Categorical(logits = logits)
a = torch.tensor([1, 1])
print(dist.log_prob(a).size())

gives me a tensor of size (2).

Why is the log_prob for the Normal distribution a different size?

asked Mar 19 '20 by Samuel Beaussant




1 Answer

If you look at the source code of torch.distributions.Normal and find the definition of the log_prob(value) function, you can see that the main part of the calculation is:

return -((value - self.loc) ** 2) / (2 * var) - some other part

where value is a tensor containing the values for which you want to compute the log probability (in your case, a), self.loc is the mean of the distribution (in your case, means), and var is the variance, that is, the square of the standard deviation (in your case, stds**2). This is indeed the logarithm of the probability density function of the normal distribution, minus some constant terms and the logarithm of the standard deviation that I don't write above.
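To make the formula concrete, here is a small sketch (with hypothetical scalar values) that computes the full log density by hand, including the constant and log-std terms omitted above, and checks it against log_prob:

```python
import math
import torch
from torch.distributions import Normal

mean = torch.tensor(0.0538)
std = torch.tensor(0.7865)
value = torch.tensor(1.2)

dist = Normal(mean, std)

# log N(value | mean, std) = -(value - mean)^2 / (2 * var) - log(std) - log(sqrt(2*pi))
var = std ** 2
manual = -((value - mean) ** 2) / (2 * var) - std.log() - 0.5 * math.log(2 * math.pi)

print(torch.allclose(manual, dist.log_prob(value)))  # True
```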

In the first example, you define means and stds as column vectors, while a is a row vector:

means = torch.tensor([[0.0538],
    [0.0651]])
stds = torch.tensor([[0.7865],
    [0.7792]])
a = torch.tensor([1.2,3.4])

But subtracting a row vector from a column vector, which is what the code does in value - self.loc, broadcasts to a matrix (try it!). Thus the result you obtain is one log_prob value for each of your two distributions crossed with each of the values in a.
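You can see the broadcasting in isolation with the shapes from the question, without involving the distribution at all:

```python
import torch

loc = torch.tensor([[0.0538],
                    [0.0651]])      # column vector, shape (2, 1)
a = torch.tensor([1.2, 3.4])        # row vector, shape (2,)

# (2, 1) minus (2,) broadcasts to a (2, 2) matrix,
# which is why log_prob comes out with shape (2, 2).
diff = a - loc
print(diff.shape)  # torch.Size([2, 2])
```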

If you want to obtain a log_prob without the cross terms, then define the tensors with consistent shapes, i.e., either

means = torch.tensor([[0.0538],
    [0.0651]])
stds = torch.tensor([[0.7865],
    [0.7792]])
a = torch.tensor([[1.2],[3.4]])

or

means = torch.tensor([0.0538,
    0.0651])
stds = torch.tensor([0.7865,
    0.7792])
a = torch.tensor([1.2,3.4])

This is what you do in your second example, which is why you obtain the result you expected.
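Putting both consistent variants together, a quick check (using the values from the question) confirms that each one yields one log probability per action rather than a cross matrix:

```python
import torch
from torch.distributions import Normal

# Variant 1: everything as column vectors -> log_prob has shape (2, 1)
dist_col = Normal(torch.tensor([[0.0538], [0.0651]]),
                  torch.tensor([[0.7865], [0.7792]]))
print(dist_col.log_prob(torch.tensor([[1.2], [3.4]])).shape)  # torch.Size([2, 1])

# Variant 2: everything as flat vectors -> log_prob has shape (2,)
dist_flat = Normal(torch.tensor([0.0538, 0.0651]),
                   torch.tensor([0.7865, 0.7792]))
print(dist_flat.log_prob(torch.tensor([1.2, 3.4])).shape)  # torch.Size([2])
```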

answered Sep 28 '22 by AndrisP