tensorflow: how come gather_nd is differentiable?

I'm looking at a tensorflow network implementing reinforcement learning for the CartPole OpenAI Gym environment.

The network implements the likelihood ratio approach for a policy gradient agent.

The thing is that the policy loss is defined using the gather_nd op! Here, look:

    ....
    # action probabilities: softmax along dim=0 (the action dimension)
    self.y = tf.nn.softmax(tf.matmul(self.W3, self.h2) + self.b3, dim=0)
    # observed reward/return per timestep
    self.curr_reward = tf.placeholder(shape=[None], dtype=tf.float32)
    # one [row, col] index pair into self.y per action taken
    self.actions_array = tf.placeholder(shape=[None, 2], dtype=tf.int32)
    # probabilities of the actions that were actually taken
    self.pai_array = tf.gather_nd(self.y, self.actions_array)
    # likelihood-ratio policy-gradient loss
    self.L = -tf.reduce_mean(tf.log(self.pai_array) * self.curr_reward)
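
For reference, tf.gather_nd(self.y, self.actions_array) pulls out one element of self.y per [row, col] index pair, i.e. the probability the network assigned to the action actually taken at each timestep. A tiny standalone illustration (made-up values, not from the repo):

    import tensorflow as tf  # TF 1.x API, as in the question

    probs = tf.constant([[0.2, 0.7],
                         [0.8, 0.3]])
    pairs = tf.constant([[0, 1], [1, 0]])  # one [row, col] pair per timestep
    picked = tf.gather_nd(probs, pairs)    # -> [0.7, 0.8]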

And then they take the derivative of this loss with respect to all the parameters of the network:

    self.gradients = tf.gradients(self.L, tf.trainable_variables())

How can this be? I thought the whole point of neural networks was to always work with differentiable ops, like cross-entropy, and never do something strange like selecting indices of self.y according to self.actions_array, which is chosen at random and is clearly not differentiable.

What am I missing here? Thanks!

asked Aug 15 '17 by ihadanny

3 Answers

It's differentiable w.r.t. self.y, but not w.r.t. the integer/discrete elements of self.actions_array.
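
You can check this directly: asking for the gradient w.r.t. the float tensor gives a real gradient op, while the integer index tensor gets none. A minimal sketch (made-up values, TF 1.x API as in the question):

    import tensorflow as tf

    y = tf.Variable([[0.2, 0.8],
                     [0.6, 0.4]])
    idx = tf.constant([[0, 1], [1, 0]])  # integer indices
    loss = -tf.reduce_mean(tf.log(tf.gather_nd(y, idx)))

    print(tf.gradients(loss, y))    # a real gradient tensor
    print(tf.gradients(loss, idx))  # [None]: no gradient flows into the indices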

answered Oct 17 '22 by Andrei Pokrovsky


The gradient is one where an element is gathered and zero where it is not. One use case for the gather operator is to act like a sparse one-hot matrix multiplication: the second argument is the dense representation of the sparse matrix, and you "multiply" it with the first argument by just selecting the right rows.
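
A small sketch of that equivalence, using tf.gather with made-up tensors: gathering rows gives the same result as materializing the one-hot selector matrix and multiplying.

    import tensorflow as tf

    params = tf.constant([[1., 2.],
                          [3., 4.],
                          [5., 6.]])
    rows = tf.constant([2, 0])               # dense representation of the sparse selector

    gathered = tf.gather(params, rows)       # selects rows 2 and 0
    selector = tf.one_hot(rows, depth=3)     # the sparse one-hot matrix, materialized
    multiplied = tf.matmul(selector, params) # same result: [[5., 6.], [1., 2.]]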

answered Oct 17 '22 by Aaron


There is no official documentation on this, but according to this issue: https://github.com/tensorflow/models/issues/295, the gradient of tf.gather in the TensorFlow implementation is 1 w.r.t. self.y and 0 w.r.t. the index. Therefore, it will not propagate gradients through the index.
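
That behaviour is easy to confirm: the gradient of a sum of gathered elements is 1 exactly at the gathered positions and 0 everywhere else. A quick check (TF 1.x, made-up values):

    import tensorflow as tf

    y = tf.Variable([[0.1, 0.9],
                     [0.7, 0.3]])
    idx = tf.constant([[0, 1], [1, 0]])
    loss = tf.reduce_sum(tf.gather_nd(y, idx))

    grad_y = tf.gradients(loss, y)[0]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(grad_y))
        # [[0. 1.]
        #  [1. 0.]]  -- 1 where an element was gathered, 0 elsewhere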

answered Oct 17 '22 by lam