
`get_variable()` doesn't recognize existing variables for tf.estimator

A similar question has been asked here; the difference is that my problem is focused on Estimator.

Some context: we have trained a model using Estimator and defined some variables within the Estimator's input_fn, which preprocesses data into batches. Now we are moving to prediction. During prediction, we use the same input_fn to read in and process the data, but we get an error saying the variable (word_embeddings) does not exist (the variables do exist in the checkpoint graph). Here's the relevant bit of code in input_fn:

with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)

Basically, in prediction mode the else branch is invoked to load the variable from the checkpoint. The failure to recognize this variable indicates either a) inappropriate usage of scope, or b) that the graph is not restored. I don't think the scope matters much here as long as reuse is set properly.

I suspect it is because the graph is not yet restored at the input_fn stage. Usually, the graph is restored by calling saver.restore(sess, "/tmp/model.ckpt") (reference). Digging through the Estimator source code didn't turn up anything related to restoring; the best lead is MonitoredSession, a wrapper around the training session. This has already stretched so far from the original problem that I'm not confident I'm on the right path, so I'm looking for help here if anyone has any insights.
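For reference, outside of Estimator the classic restore pattern looks roughly like this (a minimal sketch; the embedding shape here is a made-up placeholder):

import tensorflow as tf

# Rebuild the same variables under the same scope as at training time,
# then assign their checkpointed values; no initializer needs to run.
with tf.variable_scope('vocabulary'):
    word_embeddings = tf.get_variable("word_to_vec", shape=[10000, 300],
                                      dtype=tf.float32, trainable=False)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")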

One-line summary of my question: how does the graph get restored within tf.estimator, via input_fn or model_fn?

asked Nov 26 '18 by GabrielChu


People also ask

What is the difference between TF variable and TF get_variable?

As far as I know, tf.Variable is the default operation for creating a variable, while tf.get_variable is mainly used for weight sharing. Some people suggest using get_variable instead of the primitive Variable operation whenever you need a variable.
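A quick illustration of the difference (TF1-style; the names here are just for demonstration):

import tensorflow as tf

# tf.Variable always creates a fresh variable; duplicate names are uniquified.
v1 = tf.Variable(0.0, name="v")
v2 = tf.Variable(0.0, name="v")  # silently becomes "v_1", a distinct variable

print(v1.name, v2.name)  # v:0 v_1:0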

What is TF get_variable?

The function tf.get_variable() returns the existing variable with the same name if it exists, and creates a variable with the specified shape and initializer if it does not exist.
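For example (a minimal sketch):

import tensorflow as tf

with tf.variable_scope("demo"):
    w1 = tf.get_variable("w", shape=[2], initializer=tf.zeros_initializer())

with tf.variable_scope("demo", reuse=True):
    w2 = tf.get_variable("w")  # returns the existing "demo/w"

print(w1 is w2)  # True: the same variable object is shared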

What is TF estimator?

The Estimator object wraps a model which is specified by a model_fn, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions. All outputs (checkpoints, event files, etc.) are written to model_dir.
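For context, a bare-bones model_fn has roughly this shape (a sketch of the general pattern, not the asker's actual model):

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Forward pass: a single dense layer over the input feature.
    x = tf.cast(tf.expand_dims(features["x"], -1), tf.float32)  # [batch, 1]
    predictions = tf.squeeze(tf.layers.dense(x, units=1), axis=-1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(tf.cast(labels, tf.float32), predictions)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)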


1 Answer

Hi, I think your error comes simply from the fact that you didn't specify the shape in tf.get_variable at predict time; it seems that you need to specify the shape even if the variable is going to be restored.

I've made the following test with a simple linear regressor Estimator that simply needs to predict x + 5:

import numpy as np
import tensorflow as tf

def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Created with an initializer, so no shape is needed.
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x': x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                # No initializer here, so the shape must be given explicitly;
                # the value itself is restored from the checkpoint.
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x': [0, 10, 100, var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')

This code works perfectly fine: the restored value of the variable is 20, and for fun I also used it in my test set to confirm :p

However, if you remove shape=[], it breaks. You can also give another initializer, such as tf.constant(500), and everything will still work: the checkpointed value of 20 will be used.
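In other words, at predict time the difference is between these two calls:

# Breaks: no shape and no initializer, so get_variable cannot build the variable
var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32)

# Works: the shape is declared up front; the value comes from the checkpoint
var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])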

By running

model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)

and

preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))

You can visualize the graph (e.g. with TensorBoard, pointed at the ./outdir model directory) and you'll see that a) the scoping is normal and b) the graph is restored.

Hope this will help you.

answered Sep 24 '22 by abcdaire