
Tensorflow: Keras, Estimators and custom input function

TF 1.4 made Keras an integral part of TensorFlow. When I try to create Estimators from Keras models with a custom input function (i.e., not using tf.estimator.inputs.numpy_input_fn), things do not work, because TensorFlow cannot fuse the model with the input function.

I am using tf.keras.estimator.model_to_estimator:

keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model,
    config=run_config)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, 
                                    max_steps=self.train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                  steps=None)

tf.estimator.train_and_evaluate(keras_estimator, train_spec, eval_spec)

and I get the following error message:

    Cannot find %s with name "%s" in Keras Model. It needs to match '
              'one of the following:

I found some reference to this topic here (strangely enough, it's hidden in the TF docs on the master branch; compare to this).

If you have the same issue, see my answer below. It might save you several hours.

Shahar Karny asked Dec 22 '17 21:12


1 Answer

So here is the deal. You must make sure that your custom input function returns two dictionaries: one of inputs (features) and one of outputs (labels). The dictionary keys must match the names of your Keras input and output layers.
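As a rough illustration of that contract, the sketch below checks a features or labels dictionary against a list of layer names before training starts. It is plain Python with no TensorFlow import. The helper `check_keys` and the names `input_1` and `image` are hypothetical; in practice the layer names would come from your Keras model.

```python
def check_keys(batch_dict, layer_names):
    """Return (missing, unexpected) for a features or labels dict.

    missing:    layer names that have no entry in batch_dict
    unexpected: dict keys that match no layer name
    """
    keys = set(batch_dict)
    names = set(layer_names)
    return sorted(names - keys), sorted(keys - names)


# Hypothetical names; a string stands in for the real tensor.
features = {"image": "image_batch"}
print(check_keys(features, ["input_1"]))  # -> (['input_1'], ['image'])
```

Two empty lists mean the dictionary keys and the layer names line up.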

From TF docs:

First, recover the input name(s) of Keras model, so we can use them as the feature column name(s) of the Estimator input function

This is correct. Here is how I did this:

# Get the input and output names of the Keras model so the input function can use them as dictionary keys.
keras_input_names_list = keras_model.input_names
keras_target_names_list = keras_model.output_names

Now that you have the names, change your own input function so that it returns two dictionaries keyed by the corresponding input and output names.

In my example, before the change, the input function returned [image_batch], [label_batch]. This is basically a bug, because the docs state that the input_fn returns a dictionary, not a list.

To solve this, we need to wrap each list in a dict:

image_batch_dict = dict(zip(keras_input_names_list , [image_batch]))
label_batch_dict = dict(zip(keras_target_names_list , [label_batch]))

Only now will TF be able to connect the input function to the Keras input layers.
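If you would rather not edit the input function itself, one option is to wrap it. The sketch below is a minimal version of that idea, assuming the original function returns a pair of lists as in my example. `to_named_dict` and `wrap_input_fn` are hypothetical helpers, not part of TensorFlow, and strings stand in for the tensors so the snippet runs without TF installed.

```python
def to_named_dict(names, batches):
    """Pair each layer name with its batch, refusing mismatched lengths.

    A bare dict(zip(names, batches)) silently drops the extras when the
    two sequences differ in length; this raises instead.
    """
    if len(names) != len(batches):
        raise ValueError(
            "expected %d batches for names %r, got %d"
            % (len(names), list(names), len(batches))
        )
    return dict(zip(names, batches))


def wrap_input_fn(raw_input_fn, input_names, output_names):
    """Turn an input function returning (list, list) into one returning (dict, dict)."""
    def input_fn():
        feature_batches, label_batches = raw_input_fn()
        return (to_named_dict(input_names, feature_batches),
                to_named_dict(output_names, label_batches))
    return input_fn


# Stand-in for the original input function; names are hypothetical.
def raw_fn():
    return ["image_batch"], ["label_batch"]

train_input_fn = wrap_input_fn(raw_fn, ["input_1"], ["dense_1"])
features, labels = train_input_fn()
```

The length check is there because zip stops at the shorter sequence, so a model with two inputs and an input function that yields one batch would otherwise lose a key without any error.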

Shahar Karny answered Sep 23 '22 12:09