
How to make the weights of an RNN cell untrainable in Tensorflow?

I'm trying to make a Tensorflow graph where part of the graph is already pre-trained and running in prediction mode, while the rest trains. I've defined my pre-trained cell like so:

import tensorflow as tf

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

# pretrained state values loaded elsewhere, wrapped as non-trainable variables
state0 = tf.Variable(pretrained_state0, trainable=False)
state1 = tf.Variable(pretrained_state1, trainable=False)
pretrained_state = [state0, state1]

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state=pretrained_state)

Setting trainable=False on these variables doesn't help: they only provide the initial state of the cell, not its weights, so the cell's weights still change during training.
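
For reference, the cell's kernel and bias are created as separate trainable variables inside static_rnn, which you can see by listing the trainable variables (exact names may vary by TF version):

for v in tf.trainable_variables():
    print(v.name)
# prints something like:
#   rnn/basic_lstm_cell/kernel:0
#   rnn/basic_lstm_cell/bias:0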

I still need to run an optimizer in my training step, since the rest of my model needs to train. But how can I prevent the optimizer from changing the weights in this rnn cell?

Is there an rnn_cell equivalent of trainable=False?

Asked Jul 06 '17 by AlexR

1 Answer

You can either use tf.stop_gradient() to keep gradients from flowing back into the pretrained part of the graph, or pass an explicit var_list to the optimizer's minimize() call so that only the variables you choose get trained. The second method would involve the following (a sketch of the stop_gradient route is included after the snippet):

 # Build the trainable parts of the graph inside a variable scope so their
 # names share a common prefix, e.g.:
 #   with tf.variable_scope('train'):
 #       ...build the layers that should keep training here...

 # collect only the trainable variables from that scope
 t_vars = tf.trainable_variables()
 train_vars = [var for var in t_vars if var.name.startswith('train')]

 # let the optimizer update only those variables
 opt = optimizer.minimize(cost, var_list=train_vars)
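
For completeness, here is a minimal sketch of the first route with tf.stop_gradient(); data_input and rnn_cell are stand-ins for the question's tensors, and the pretrained initial state is omitted for brevity:

import tensorflow as tf

# stand-ins for the question's inputs: a length-5 sequence of [batch, 64] tensors
data_input = tf.unstack(tf.placeholder(tf.float32, [None, 5, 64]), axis=1)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32)

# Cut the gradient path at the RNN's outputs: layers built on top of
# frozen_outputs still train, but no gradient reaches the LSTM's kernel
# and bias, so optimizer.minimize() leaves them untouched.
frozen_outputs = [tf.stop_gradient(o) for o in outputs]

# If the final state is also fed downstream, stop its gradients as well:
# frozen_state = tf.contrib.rnn.LSTMStateTuple(tf.stop_gradient(states.c),
#                                              tf.stop_gradient(states.h))

Either way the rest of the model keeps training; var_list is explicit about which variables get updated, while stop_gradient freezes the pretrained part by cutting the gradient path at a chosen tensor.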
Answered Nov 10 '22 by vijay m