I'm trying to build a TensorFlow graph where one part is already pre-trained and runs in prediction mode, while the rest trains. I've defined my pre-trained cell like so:
rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)
state0 = tf.Variable(pretrained_state0, trainable=False)
state1 = tf.Variable(pretrained_state1, trainable=False)
pretrained_state = [state0, state1]
outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state=pretrained_state)
Setting the initial-state variables to trainable=False doesn't help: they are only used to initialize the cell's state, not its weights, and the weights still change during training.
I still need to run an optimizer in my training step, since the rest of my model needs to train. But how can I prevent the optimizer from changing the weights in this RNN cell? Is there an rnn_cell equivalent of trainable=False?
There are two options. You can use tf.stop_gradient() to prevent gradients from flowing back into the pretrained part of the graph, or you can tell the optimizer explicitly which variables it is allowed to update via the var_list argument of minimize().
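For the first option, a minimal sketch, reusing rnn_cell, data_input, and pretrained_state from your question, and assuming every path from the loss back into the cell goes through outputs:

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state=pretrained_state)
# Cut the gradient path at the RNN boundary: downstream layers still
# receive the output values, but backprop stops here, so the cell's
# weights receive no gradient and are never updated by the optimizer.
frozen_outputs = [tf.stop_gradient(o) for o in outputs]
# Build the trainable part of the model on top of frozen_outputs.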
The second method would involve:
# Build the trainable parts of the graph inside a variable scope,
# e.g. with tf.variable_scope('train'): ...

# Collect all trainable variables, then keep only those from that scope
t_vars = tf.trainable_variables()
train_vars = [var for var in t_vars if var.name.startswith('train')]

# Train only the variables of that scope
opt = optimizer.minimize(cost, var_list=train_vars)
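Putting this together with your setup, a hedged end-to-end sketch; the dense head, num_classes, labels, and learning rate are illustrative placeholders, not from your post:

num_classes = 10  # placeholder value for illustration
labels = tf.placeholder(tf.float32, [None, num_classes])

with tf.variable_scope('train'):
    # Only variables created inside this scope will be optimized
    logits = tf.layers.dense(outputs[-1], num_classes)

cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

train_vars = [var for var in tf.trainable_variables()
              if var.name.startswith('train')]
opt = tf.train.AdamOptimizer(0.001).minimize(cost, var_list=train_vars)

Note the practical difference between the two approaches: tf.stop_gradient() cuts the backward path entirely, so nothing upstream of the cut receives gradients through it, while var_list still lets gradients flow through the frozen subgraph and simply never applies updates to its variables, which matters if you ever place trainable layers upstream of the pretrained cell.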