How to get the Q-values in DQN in Stable Baselines3?

I have an observation space of type Box whose bounds are defined with NumPy arrays, and my observations are NumPy arrays too.

For example:

Box(low=np.array([0, 0, 0]), high=np.array([15, 10, 150]))

Now I want to get the Q-values for a single observation. Since the observation space is a Box, the relevant Stable Baselines3 preprocessing code is:

if isinstance(observation_space, spaces.Box):
    return obs.float()

But my input observation is a NumPy array, which has no float() method (float() is a PyTorch tensor method). In this case, how can I access the Q-values of all the actions?
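A minimal sketch, assuming NumPy and PyTorch are installed, reproduces the mismatch: the Box branch calls .float(), which exists on torch tensors but not on NumPy arrays. The observation values and the `obs_shape` stand-in for `model.observation_space.shape` are made up for illustration.

```python
import numpy as np
import torch as th

# Stand-in for model.observation_space.shape from the question's Box space.
obs_shape = (3,)
# A single observation as a NumPy array (values picked arbitrarily).
obs = np.array([3.0, 7.0, 42.0], dtype=np.float32)

# NumPy arrays have no .float() method, so this raises AttributeError.
try:
    obs.float()
except AttributeError as err:
    print("NumPy array:", err)

# Add a batch dimension, (3,) -> (1, 3), then convert to a tensor;
# tensors do have .float(), which is what SB3's preprocessing calls.
batch = obs.reshape((-1,) + obs_shape)
tensor_obs = th.as_tensor(batch).float()
print(tensor_obs.shape)  # torch.Size([1, 3])
```

Inside SB3, `stable_baselines3.common.utils.obs_as_tensor` performs the tensor conversion and also places the data on a given device.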

asked by naizz, Oct 17 '25 at 08:10



1 Answer

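I figured out how to resolve it, so I'm posting the solution here in case someone else runs into the same problem. Reshape the NumPy observation to add a batch dimension, convert it to a tensor, and pass it to the Q-network directly: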

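import torch as th
from stable_baselines3.common.utils import obs_as_tensor

# Add a batch dimension, then convert to a tensor on the model's device
observation = obs.reshape((-1,) + model.observation_space.shape)
observation = obs_as_tensor(observation, model.device)
with th.no_grad():
    q_values = model.q_net(observation)  # shape: (batch_size, n_actions)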
answered by naizz, Oct 18 '25 at 23:10
