
How to provide custom gradient in TensorFlow

I am trying to understand how to use the @tf.custom_gradient decorator, available in TensorFlow 1.7, to provide a custom gradient of a vector with respect to a vector. The code below is a minimum working example that solves the following problem to get dz/dx.

y = Ax
z = ||y||^2

The attached image also shows the solution obtained by manual calculation.

If I do not use @tf.custom_gradient, TensorFlow gives the desired solution. My question is: how can I provide a custom gradient for y = Ax? We know that dy/dx = A^T, as shown in the attachment, whose calculation steps match the TensorFlow output.
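For reference, the manual calculation in the attachment can be reproduced numerically with plain NumPy (a quick sanity check of the chain rule dz/dx = A^T (dz/dy), independent of TensorFlow):

```python
import numpy as np

A = np.array([[1., 2.], [3., 4.]])
x = np.array([[1.], [0.]])

y = A @ x            # forward pass: y = Ax
dz_dy = 2 * y        # dz/dy for z = sum(y^2)
dz_dx = A.T @ dz_dy  # chain rule: dz/dx = A^T (dz/dy)
print(dz_dx)         # [[20.], [28.]]
```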

import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return dzByDx
    return y,grad


x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

y=f1(A,x) # This works as desired
#y=f2(A,x) #This line gives Error


z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print(sess.run(g))
asked Oct 17 '22 by Hemant

1 Answer

Since your function f2() has two inputs, you have to provide a gradient to flow back to each of them. The error you see:

Num gradients 2 generated for op name: "IdentityN" [...] do not match num inputs 3

is admittedly quite cryptic, though. Supposing you never want to calculate dy/dA, you can just return None, dzByDx. The code below (tested) works:

import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return None, dzByDx
    return y,grad

x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

#y=f1(A,x) # The original version without custom gradient
y=f2(A,x)  # Now works: grad() returns one gradient per input

z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print(sess.run(g))

outputs:

[array([[20.], [28.]], dtype=float32)]

as desired.
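For completeness, the same custom gradient carries over to TensorFlow 2.x (an assumption on my part; the question targets 1.7), where eager execution replaces Session and tf.GradientTape replaces tf.gradients:

```python
import tensorflow as tf

@tf.custom_gradient
def f2(A, x):
    y = tf.matmul(A, x)
    def grad(dz_dy):
        # dz/dx = A^T (dz/dy); None means "no gradient" for A
        return None, tf.matmul(A, dz_dy, transpose_a=True)
    return y, grad

A = tf.constant([[1., 2.], [3., 4.]])
x = tf.constant([[1.], [0.]])

with tf.GradientTape() as tape:
    tape.watch(x)               # x is a constant, so watch it explicitly
    z = tf.reduce_sum(f2(A, x) ** 2)

g = tape.gradient(z, x)         # [[20.], [28.]]
```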

answered Oct 21 '22 by Peter Szoldan