
How to get the default session from a tf.estimator?

I am trying the high-level tf.estimator API, but I find it hard to get hold of the session in order to debug intermediate results such as the global step.

cls = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': fcs,
        'hidden_units': [10, 10],
        'n_classes': 3,
    })

The example is from https://www.tensorflow.org/versions/master/get_started/custom_estimators

I have tried sess = tf.get_default_session() and with tf.Session() as sess, but neither gives me the default session.

imhuay asked Mar 15 '18 08:03



1 Answer

The easiest thing would be to use tf.Print like:

...
global_step = tf.Print(global_step, [global_step], message='Value of global step')
...

You can replace global_step with any tensor you want printed. Then, when you run the training, the values are printed to standard error every time the tensor returned by tf.Print is evaluated.
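Note that tf.Print only fires when the op it creates is actually executed, so it has to sit somewhere on the path of what the training loop runs. Here is a minimal sketch of one way to guarantee that, assuming the TF 1.x graph-mode API (imported through tensorflow.compat.v1). The function build_train_op, the variable w and the toy loss are hypothetical stand-ins for the corresponding part of your model_fn, not anything from the tutorial.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style graph mode

tf.disable_v2_behavior()


def build_train_op():
    """Toy stand-in for the loss/train_op part of a model_fn (hypothetical)."""
    global_step = tf.train.get_or_create_global_step()
    w = tf.get_variable('w', initializer=5.0)  # hypothetical parameter
    loss = tf.square(w)

    # tf.Print is an identity op that logs `data` to stderr as a side effect.
    print_op = tf.Print(global_step, [global_step, loss],
                        message='global_step, loss: ')

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # The control dependency makes every run of train_op execute print_op
    # first. global_step itself stays a variable for the optimizer.
    with tf.control_dependencies([print_op]):
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op, global_step


if __name__ == '__main__':
    # A plain session stands in for the Estimator's training loop here.
    with tf.Graph().as_default():
        train_op, global_step = build_train_op()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(3):
                sess.run(train_op)  # logs the step and loss each time
```

Inside a real model_fn you would return this train_op in the tf.estimator.EstimatorSpec, and the Estimator would then trigger the print on every training step without you ever touching its session.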

Another, more complicated way is to export the model, then load it back in using your own session (not the estimator API). Once you do this you can call session.run for any operation defined in the graph. You can look up operations and tensors with graph.get_operation_by_name or graph.get_tensor_by_name (both are methods of tf.Graph, e.g. the sess.graph of your session). You can also feed whatever values you want as input.
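A rough sketch of that round trip, again assuming the TF 1.x graph-mode API via tensorflow.compat.v1. To keep it self-contained, a tiny hand-built graph is saved with tf.saved_model.simple_save instead of a real Estimator export. The names export_toy_model, run_by_name, x, w and y are made up for illustration, and the tensor names in your own exported model will differ.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style graph mode

tf.disable_v2_behavior()


def export_toy_model(export_dir):
    """Hypothetical stand-in for an Estimator export: saves y = x * w."""
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None], name='x')
        w = tf.get_variable('w', initializer=2.0)
        y = tf.multiply(x, w, name='y')
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # export_dir must not exist yet.
            tf.saved_model.simple_save(
                sess, export_dir, inputs={'x': x}, outputs={'y': y})


def run_by_name(export_dir, fetch_name, feed):
    """Loads the SavedModel into our own session and runs a tensor by name."""
    with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
            tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], export_dir)
            fetch = graph.get_tensor_by_name(fetch_name)
            # Map tensor names to tensors so any tensor can be fed.
            feed_dict = {graph.get_tensor_by_name(name): value
                         for name, value in feed.items()}
            return sess.run(fetch, feed_dict=feed_dict)


if __name__ == '__main__':
    export_dir = os.path.join(tempfile.mkdtemp(), 'model')
    export_toy_model(export_dir)
    print(run_by_name(export_dir, 'y:0', {'x:0': [1.0, 2.0]}))
```

Tensor names take the form op_name:output_index, which is why the placeholder named x is addressed as x:0.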

Sorin answered Oct 18 '22 06:10
