I'm trying to do transfer learning with an Inception-ResNet v2 model pretrained on ImageNet, using my own dataset and classes. My original codebase was a modification of a tf.slim sample which I can't find anymore, and now I'm trying to rewrite the same code using the tf.estimator.* framework.
However, I am running into the problem of loading only some of the weights from the pretrained checkpoint, initializing the remaining layers with their default initializers. Researching the problem, I found this GitHub issue and this question, both mentioning the need to use tf.train.init_from_checkpoint in my model_fn. I tried, but given the lack of examples in both, I guess I got something wrong.
This is my minimal example:
import sys
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
import numpy as np

import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299


def input_fn(mode, num_classes, batch_size=1):
    # some code that loads images, reshapes them to 299x299x3 and batches them
    return (tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)),
            tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES))
def model_fn(images, labels, num_classes, mode):
    with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2.inception_resnet_v2(
            images, num_classes, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
    scopes = {os.path.dirname(v.name) for v in variables_to_restore}
    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                  {s + '/': s + '/' for s in scopes})

    tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
        train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op)
def main(unused_argv):
    # Create the Estimator
    classifier = tf.estimator.Estimator(
        model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
        model_dir='model/MCVE')

    # Train the model
    classifier.train(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
        steps=1000)

    # Evaluate the model and print results
    eval_results = classifier.evaluate(
        input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
    print()
    print('Evaluation results:\n %s' % eval_results)


if __name__ == '__main__':
    tf.app.run(main=main, argv=[sys.argv[0]])
where inception_resnet_v2 is the model implementation from TensorFlow's models repository.
If I run this script, I get a bunch of info logging from init_from_checkpoint, but then, at session creation time, it seems to attempt loading the Logits weights from the checkpoint and fails because of incompatible shapes. This is the full traceback:
Traceback (most recent call last):
  File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
    runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]])
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
    steps=1000)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
    return self._sess_creator.create_session()
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
    init_fn=self._scaffold.init_fn)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)
  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001]
   [[Node: Assign_1145 = Assign[T=DT_FLOAT, _class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]
What am I doing wrong when using init_from_checkpoint? How exactly are we supposed to "use" it in our model_fn? And why is the estimator trying to load the Logits weights from the checkpoint when I'm explicitly telling it not to?
After the suggestion in the comments, I tried alternative ways to call tf.train.init_from_checkpoint.
{v.name: v.name}
If, as suggested in the comment, I replace the call with {v.name: v.name for v in variables_to_restore}, I get this error:
ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.
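For context, the 'scope/': 'other_scope/' form the error message asks for maps a whole checkpoint scope onto a graph scope; an illustrative sketch, restoring just a single block:

# Scope-to-scope form: copies every checkpoint variable under the left-hand
# scope into the matching variable under the right-hand scope.
tf.train.init_from_checkpoint(
    'inception_resnet_v2_2016_08_30.ckpt',
    {'InceptionResnetV2/Conv2d_2a_3x3/': 'InceptionResnetV2/Conv2d_2a_3x3/'})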
{v.name: v}
If, instead, I try using the name: variable mapping, I get the following error:
ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in
inception_resnet_v2_2016_08_30.ckpt checkpoint
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256],
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...
The error continues listing what I think are all the variable names in the checkpoint (or could it be the scopes instead?).
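One way to settle that doubt is to list the checkpoint's contents directly; a quick sketch (tf.train.list_variables lives next to init_from_checkpoint in checkpoint_utils):

# Print every variable name and shape stored in the checkpoint.
for name, shape in tf.train.list_variables('inception_resnet_v2_2016_08_30.ckpt'):
    print(name, shape)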
After inspecting the latest error above, I see that InceptionResnetV2/Conv2d_2a_3x3/weights is in the list of the checkpointed variables. The problem is the :0 at the end of the name! I'll now verify whether this does indeed solve the problem and post an answer if that's the case.
Thanks to @KathyWu's comment, I got on the right track and found the problem.
Indeed, the way I was computing the scopes would include the InceptionResnetV2/ scope, which would trigger the load of all variables "under" that scope (i.e., all variables in the network). Replacing this with the correct dictionary, however, was not trivial.
Of the possible scope modes init_from_checkpoint accepts, the one I had to use was the 'scope_variable_name': variable one, but without using the actual variable.name attribute.
A variable.name looks like 'some_scope/variable_name:0'. That :0 is not part of the checkpointed variable's name, so using scopes = {v.name: v.name for v in variables_to_restore} will raise a "Variable not found" error.
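A minimal illustration of the naming mismatch, using a hypothetical variable:

# The trailing ':0' is the index of the variable's output tensor,
# not part of the name under which it is saved in a checkpoint.
v = tf.Variable(tf.zeros([1]), name='some_scope/variable_name')
print(v.name)                # 'some_scope/variable_name:0' -> not a checkpoint key
print(v.name.split(':')[0])  # 'some_scope/variable_name'   -> matches checkpoint keys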
The trick to make it work was stripping the tensor index from the name:
tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                              {v.name.split(':')[0]: v for v in variables_to_restore})
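Putting it together, the checkpoint-loading part of the model_fn becomes (a sketch, using the same exclude list as in the question):

exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
# Map each checkpoint variable name (no ':0' suffix) to the graph variable itself.
tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                              {v.name.split(':')[0]: v for v in variables_to_restore})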
I found out that {s + '/': s + '/' for s in scopes} didn't work just because variables_to_restore includes something like "global_step", so the scopes include the global scope, which could include everything. You need to print variables_to_restore, find the "global_step" entry, and put it in "exclude".
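In code, that fix would look something like this (a sketch, assuming the global step is the only stray variable outside the model's scope):

# Also exclude the global step, so the derived scope set doesn't contain the
# empty (root) scope, whose prefix would map every variable in the graph.
exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits', 'global_step']
variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)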