 

Can a TF-Agents policy return a probability vector for all actions?

I am trying to train a Reinforcement Learning agent using TF-Agents, following the TF-Agents DQN tutorial. In my application, I have a single action with 9 possible discrete values (labeled 0 to 8). Below is the output of env.action_spec():

BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(8, dtype=int64))

I would like to get a probability vector containing the probabilities of all actions as calculated by the trained policy, and do further processing in other application environments. However, the policy only returns log_probability with a single value rather than a vector over all actions. Is there any way to get the probability vector?

import tensorflow as tf

from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent
from tf_agents.policies import policy_saver
from tf_agents.utils import common

q_net = q_network.QNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=(32,)
        )

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
global_step = tf.compat.v1.train.get_or_create_global_step()
epsilon = 0.1  # exploration rate; example value

my_agent = dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    epsilon_greedy=epsilon,
    optimizer=optimizer,
    emit_log_probability=True,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)

my_agent.initialize()

...  # training

tf_policy_saver = policy_saver.PolicySaver(my_agent.policy)
tf_policy_saver.save('./policy_dir/')
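# (assumed, following the TF-Agents DQN tutorial: load the saved policy back;
#  this is the `saved_policy` referenced further below)
saved_policy = tf.saved_model.load('./policy_dir/')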

# making decision using the trained policy
action_step = my_agent.policy.action(time_step)

In dqn_agent.DqnAgent(), I set emit_log_probability=True, which is documented as defining whether policies emit log probabilities or not.

However, when I run action_step = my_agent.policy.action(time_step), it returns

PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1], dtype=int64)>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>))

I also tried to run action_distribution = saved_policy.distribution(time_step), and it returns

PolicyStep(action=<tfp.distributions.DeterministicWithLogProb 'Deterministic' batch_shape=[1] event_shape=[] dtype=int64>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>))

If there is no such API available in TF-Agents, is there another way to get such a probability vector? Thanks.


Follow-up Question:

If I understand correctly, a deep Q-network takes the state as input and outputs a Q-value for each action available in that state. I could pass this Q-value vector into a softmax function and calculate the corresponding probability vector, as sketched below. I have actually done this calculation in my own customized DQN script (without TF-Agents). The question then becomes: how do I return the Q-value vector from TF-Agents?
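For illustration only, a minimal sketch of that softmax step, assuming a Q-value vector of shape (1, 9) is already available (the values below are made-up examples):

import tensorflow as tf

q_values = tf.constant([[1.2, 0.3, -0.5, 2.1, 0.0, 0.7, -1.1, 0.4, 1.8]])  # example Q-values for the 9 actions
action_probs = tf.nn.softmax(q_values, axis=-1)  # probability vector; sums to 1 over the 9 actions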

Asked Aug 24 '20 by BING ZHAO


1 Answer

The only way to do this in the TF-Agents framework is to invoke the Policy.distribution() method instead of the action() method. This returns the original distribution that was computed from the Q-values of the network. The emit_log_probability=True flag only affects the info attribute of the PolicyStep namedtuple that Policy.action() returns. Note that this distribution may also be affected by any action constraints you pass (if you do), whereby illegal actions are marked as having 0 probability even though their original Q-value might have been high.

If, furthermore, you would like to see the actual Q-values instead of the distribution they generate, then I'm afraid there is no way of doing this without acting directly upon the Q-network that comes with your agent (and that is also attached to the Policy object the agent produces), for example as sketched below. If you want to see how to call that Q-network properly, I recommend looking at how the QPolicy._distribution() method does it here.
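A minimal sketch of querying that Q-network directly, assuming the q_net and time_step variables from the question; the call convention of passing the observation plus step_type and receiving (q_values, network_state) is the standard TF-Agents network interface:

q_values, _ = q_net(time_step.observation, step_type=time_step.step_type)
action_probs = tf.nn.softmax(q_values, axis=-1)  # one probability per action, shape [batch, 9]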

Note that none of this can be done using the pre-implemented Drivers. You would have to either explicitly construct your own collection loop or implement your own Driver object (which is basically equivalent), for instance along the lines of the loop sketched below.
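As a sketch only (assuming a TF-Agents environment env plus the agent and Q-network from the question), a hand-rolled loop that records the softmaxed Q-values at every step might look like this:

time_step = env.reset()
while not time_step.is_last():
    # full probability vector derived from the Q-values at this step
    q_values, _ = q_net(time_step.observation, step_type=time_step.step_type)
    action_probs = tf.nn.softmax(q_values, axis=-1)

    # still act with the trained policy, then advance the environment
    action_step = my_agent.policy.action(time_step)
    time_step = env.step(action_step.action)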

Answered Nov 13 '22 by Federico Malerba