
Sampling from a Categorical distribution in PyTorch

I'm currently working on a deep reinforcement learning problem, and I'm using a Categorical distribution so the agent can pick a random action. This is the code:

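  # T is presumably `import torch as T`; Categorical comes from torch.distributions.
  def choose_action(self, enc_current_node, goal_node):
      # Build a (1, 2) float state tensor from the current and goal nodes.
      state = T.tensor([[enc_current_node, goal_node]], dtype=T.float)
      pi, v = self.forward(state)
      # Turn the policy output into probabilities over the 22 actions.
      probs = T.softmax(pi, dim=1)
      print(probs)
      dist = Categorical(probs)
      # Sample one action index; indices run from 0 to 21, not 1 to 22.
      action = dist.sample().numpy()[0]
      return action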

The probabilities and the resulting Categorical(probs) look like this:

probs=T.tensor([[1.5857e-03, 8.9753e-01, 2.8500e-03, 9.0585e-03, 3.6661e-04, 6.8342e-08,
         7.2956e-04, 3.3966e-05, 3.7150e-04, 1.8317e-05, 4.1543e-04, 4.7550e-05,
         5.2323e-05, 1.1337e-03, 1.6356e-05, 6.9848e-03, 2.2993e-03, 1.0874e-06,
         2.0343e-04, 2.3616e-03, 1.3477e-02, 6.1464e-02]])
c=Categorical(probs)
c
>>> output:
>>> Categorical(probs: torch.Size([1, 22]))

In the function, I used dist.sample() to draw one of the 22 elements, but I noticed that PyTorch's sample method returns the same number about 90% of the time, as you can see here:

samples = []  # renamed from `list` to avoid shadowing the built-in
for i in range(100):
    samples.append(c.sample()[0].item())
>>> output:
>>> [1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 21, 1, 3, 1, 1, 1, 1, 1, 1, 21, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 21, 1, 3, 1, 1, 1, 1, 1, 18, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3]

As you can see above, the sample method outputs a lot of 1s. My question: is there a way to draw a random sample from the Categorical distribution rather than this?

noob asked Sep 18 '25 03:09


1 Answer

If you look at the probabilities you are sampling from (probs), you will see that class 1 (the second entry, since labels start at 0) has by far the largest probability, and almost all of the others are below 1%. If you are not familiar with scientific notation, here are the same values formatted as rounded percentages:

for label, p in enumerate(probs[0]):
    print(f'{label:2}: {100*p:5.2f}%')
 0:  0.16%
 1: 89.75%  <---
 2:  0.28%
 3:  0.91%
 4:  0.04%
 5:  0.00%
 6:  0.07%
 7:  0.00%
 8:  0.04%
 9:  0.00%
10:  0.04%
11:  0.00%
12:  0.01%
13:  0.11%
14:  0.00%
15:  0.70%
16:  0.23%
17:  0.00%
18:  0.02%
19:  0.24%
20:  1.35%
21:  6.15%

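Hence ~90% of samples drawn from this distribution will be 1. The sampling is already random; it simply follows the probabilities you passed in, and those are heavily concentrated on class 1.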
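To check this empirically, here is a minimal sketch that draws many samples at once and compares the observed frequencies with the probabilities. The seed, the sample count `n_draws`, and the `temperature` value are arbitrary choices made for illustration. The last few lines go beyond the answer above: they assume you actually want more exploration, and show one possible way to flatten the distribution before sampling.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)  # arbitrary seed, only so repeated runs match

# Probabilities copied from the question (shape [1, 22]).
probs = torch.tensor([[1.5857e-03, 8.9753e-01, 2.8500e-03, 9.0585e-03, 3.6661e-04, 6.8342e-08,
                       7.2956e-04, 3.3966e-05, 3.7150e-04, 1.8317e-05, 4.1543e-04, 4.7550e-05,
                       5.2323e-05, 1.1337e-03, 1.6356e-05, 6.9848e-03, 2.2993e-03, 1.0874e-06,
                       2.0343e-04, 2.3616e-03, 1.3477e-02, 6.1464e-02]])

# Categorical rescales probs along the last dimension so they sum to 1.
dist = Categorical(probs=probs)

n_draws = 10_000  # hypothetical sample size
# sample((n,)) on a batch of size 1 returns shape (n, 1); drop the batch dim.
samples = dist.sample((n_draws,)).squeeze(1)

# Fraction of draws that landed on each class.
freq = torch.bincount(samples, minlength=probs.shape[1]).float() / n_draws

for label, (p, f) in enumerate(zip(dist.probs[0], freq)):
    print(f'{label:2}: expected {100 * p.item():5.2f}%  observed {100 * f.item():5.2f}%')

# Assumption, not part of the answer above: if more exploration is wanted,
# the distribution itself has to change, e.g. by dividing the
# log-probabilities by a temperature > 1 before renormalising.
temperature = 5.0  # hypothetical value
flat = Categorical(logits=probs.log() / temperature)
print(f'class 1 after flattening: {100 * flat.probs[0, 1].item():.2f}%')
```

The observed column should track the expected one closely. Whether flattening is appropriate depends on the training setup: in a policy-gradient agent, a peaked distribution usually just means the policy has become confident in one action.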

iacob answered Sep 20 '25 17:09