Keras: stop gradient after a certain layer

Tags:

keras

Suppose you have a Keras NN model. How can you stop the gradient during backpropagation after a certain layer?


I.e., if we have a model with two outputs:

from keras.layers import Input, Convolution2D, Activation, Flatten, Dense
from keras.models import Model

input_layer = Input(shape=(10,10,3))

x = Convolution2D(...)(input_layer)
x = Activation('relu')(x)

x = Flatten()(x)

x_1 = Dense(64)(x)
x_1 = Dense(32)(x_1)
x_1 = Dense(2)(x_1)

x_2 = Dense(64)(x)
x_2 = Dense(32)(x_2)
x_2 = Dense(2)(x_2)

model = Model(inputs=input_layer, outputs=[x_1, x_2])

How can the gradient of output x_1 be stopped after the x_1 = Dense(64)(x) layer, so that it does not contribute to the weight update of the convolutional layer?


Based on the answer in Stopping Gradient back prop through a particular layer in keras, I would add a Lambda layer before the first x_1 dense layer, but I'm not really sure:

from keras.layers import Lambda
from keras import backend as K

x_1 = Dense(64)(x)
x_1_stop_grad = Lambda(lambda x: K.stop_gradient(x))(x_1)
x_1 = Dense(32)(x_1)
x_1 = Dense(2)(x_1)

Do I have to add the Lambda layer before or after the first x_1 Dense layer?

1 Answer

Since the gradient flows backwards through the network, you need to add the gradient-stopping layer directly after the layer that should receive no gradient.

I.e.

# the weights of the convolutional layer should not be updated by gradients from x_1
x = Convolution2D(...)(input_layer)
# block the gradient right where the x_1 branch starts
x_1_stop_grad = Lambda(lambda t: K.stop_gradient(t))(x)
x_1 = Dense(64)(x_1_stop_grad)
x_1 = Dense(32)(x_1)
...
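
For reference, a minimal self-contained sketch of the full two-head model with the gradient blocked on the x_1 branch might look like the following. The convolution's filter count and kernel size are placeholder values (not from the original post), and Conv2D is used as the current name of Convolution2D:

from keras.layers import Input, Conv2D, Activation, Flatten, Dense, Lambda
from keras.models import Model
from keras import backend as K

input_layer = Input(shape=(10, 10, 3))

# shared convolutional trunk (placeholder filter count / kernel size)
x = Conv2D(16, (3, 3))(input_layer)
x = Activation('relu')(x)
x = Flatten()(x)

# head 1: stop the gradient at the start of the branch,
# so this head never updates the convolutional weights
x_1 = Lambda(lambda t: K.stop_gradient(t))(x)
x_1 = Dense(64)(x_1)
x_1 = Dense(32)(x_1)
x_1 = Dense(2)(x_1)

# head 2: gradients flow back into the convolutional trunk as usual
x_2 = Dense(64)(x)
x_2 = Dense(32)(x_2)
x_2 = Dense(2)(x_2)

model = Model(inputs=input_layer, outputs=[x_1, x_2])
model.compile(optimizer='adam', loss='mse')

With this layout, training on both outputs only updates the Dense layers of the x_1 branch from its loss, while the x_2 branch still trains the convolutional trunk.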