After reading the API docs, I still can't understand how to use SessionRunHook. For example, in what order are SessionRunHook's member functions called? Is it after_create_session -> before_run -> begin -> after_run -> end?
I also can't find a tutorial with detailed examples. Is there a more detailed explanation somewhere?
You can find a tutorial here. It is a little long, but you can skip the part about building the network. Or you can read my short summary below, based on my experience.
First, you should use MonitoredSession instead of a plain Session.
A SessionRunHook extends the session.run() calls made through the MonitoredSession: it can add extra fetches and feeds to each call and act on the results.
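To make "extends session.run()" concrete, here is a pure-Python sketch that does not use TensorFlow at all. `RunArgs`, `RunValues`, `ToySession`, `hooked_run`, and `LossWatcher` are hypothetical stand-ins (loosely modeled on SessionRunArgs, SessionRunValues, the wrapped Session, and MonitoredSession's run), and the hook methods are simplified (no `run_context`). The idea shown: each hook's before_run may ask for an extra fetch, all fetches are merged into one run call, and each hook's after_run receives only the results it asked for.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class RunArgs:
    """Stand-in for tf.train.SessionRunArgs: the extra fetch a hook asks for."""
    fetches: Optional[str] = None


@dataclass
class RunValues:
    """Stand-in for tf.train.SessionRunValues: results for that hook's fetch only."""
    results: Any = None


class ToySession:
    """Hypothetical session: each fetch name maps to a zero-argument function."""

    def __init__(self, ops: Dict[str, Callable[[], Any]]):
        self.ops = ops

    def run(self, names: List[str]) -> List[Any]:
        # Evaluated in list order (a simplification; TF gives no such guarantee).
        return [self.ops[name]() for name in names]


def hooked_run(sess: ToySession, fetch: str, hooks: list) -> Any:
    """One run call: merge hook fetches into the caller's, then split the results."""
    names = [fetch]
    slots = []  # (hook, position of its fetch in `names`, or None)
    for hook in hooks:
        args = hook.before_run()
        if args is not None and args.fetches is not None:
            slots.append((hook, len(names)))
            names.append(args.fetches)
        else:
            slots.append((hook, None))
    values = sess.run(names)  # a single run call serves the caller and all hooks
    for hook, pos in slots:
        hook.after_run(RunValues(None if pos is None else values[pos]))
    return values[0]  # the caller only sees what it asked for


class LossWatcher:
    """Hypothetical hook: asks for 'loss' on every step and remembers it."""

    def __init__(self):
        self.seen = []

    def before_run(self):
        return RunArgs(fetches="loss")

    def after_run(self, run_values):
        self.seen.append(run_values.results)


state = {"step": 0}


def train_op():
    state["step"] += 1  # pretend to take one training step


sess = ToySession({"train_op": train_op, "loss": lambda: 10 - state["step"]})
watcher = LossWatcher()
for _ in range(2):
    hooked_run(sess, "train_op", [watcher])
print(watcher.seen)
```

The caller only fetched `train_op`, yet the hook saw a loss value on every step without a second run call.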
Next, some common SessionRunHook classes can be found here. A simple one is LoggingTensorHook, but you may want to add the following line after your imports so that you can see the logs while running:

```python
tf.logging.set_verbosity(tf.logging.INFO)
```
Alternatively, you can implement your own SessionRunHook class. A simple one comes from the cifar10 tutorial:
```python
import time
from datetime import datetime

import tensorflow as tf

# `loss` (the loss tensor) and `FLAGS` (with log_frequency and batch_size)
# are defined elsewhere in the cifar10 training script.


class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results
            examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
            sec_per_batch = float(duration / FLAGS.log_frequency)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
```
where loss is defined outside the class. This _LoggerHook uses print to output the information, while LoggingTensorHook writes it through tf.logging at the INFO level.
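The throughput figures printed by _LoggerHook are plain arithmetic over the last `log_frequency` steps. The small helper below (the name `throughput` is hypothetical) just reproduces the two formulas from the hook so you can check the numbers by hand, for example 10 steps of 128 examples taking 2.0 seconds:

```python
def throughput(duration, log_frequency, batch_size):
    """Return (examples_per_sec, sec_per_batch) using the same formulas as _LoggerHook."""
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = float(duration / log_frequency)
    return examples_per_sec, sec_per_batch


eps, spb = throughput(duration=2.0, log_frequency=10, batch_size=128)
print(eps, spb)
```

Because `_start_time` is reset each time a line is logged, `duration` covers exactly the steps since the previous log line.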
Finally, to better understand how it works, here is the execution order with MonitoredSession, in pseudocode:
```
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:  # py code: while not mon_sess.should_stop():
  call hooks.before_run()
  try:
    results = sess.run(merged_fetches, feed_dict=merged_feeds)
  except (errors.OutOfRangeError, StopIteration):
    break
  call hooks.after_run()
call hooks.end()
sess.close()
```
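The pseudocode can be checked with a small pure-Python simulation that needs no TensorFlow. `RecordingHook` uses the SessionRunHook method names and signatures but only records which callback ran; `OutOfRange`, `ToyRunContext`, and `simulate` are hypothetical stand-ins for tf.errors.OutOfRangeError, tf.train.SessionRunContext, and MonitoredSession's loop. The input is assumed to run out after `num_batches` steps.

```python
class OutOfRange(Exception):
    """Stand-in for tf.errors.OutOfRangeError (input pipeline exhausted)."""


class ToyRunContext:
    """Stand-in for tf.train.SessionRunContext: only supports request_stop()."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class RecordingHook:
    """Hypothetical hook with the SessionRunHook method names; it only logs calls."""

    def __init__(self, log):
        self.log = log

    def begin(self):
        self.log.append("begin")

    def after_create_session(self, session, coord):
        self.log.append("after_create_session")

    def before_run(self, run_context):
        self.log.append("before_run")
        return None  # no extra fetches requested

    def after_run(self, run_context, run_values):
        self.log.append("after_run")

    def end(self, session):
        self.log.append("end")


def simulate(hooks, num_batches):
    """Follow the MonitoredSession pseudocode; data runs out after num_batches."""
    for hook in hooks:
        hook.begin()
    session = object()  # placeholder for tf.Session()
    for hook in hooks:
        hook.after_create_session(session, None)
    stop_requested = False
    batches_left = num_batches
    while not stop_requested:  # like `while not mon_sess.should_stop():`
        context = ToyRunContext()
        for hook in hooks:
            hook.before_run(context)
        try:
            if batches_left == 0:
                raise OutOfRange()  # sess.run() fails: input exhausted
            batches_left -= 1
        except OutOfRange:
            break  # after_run is NOT called for the failed step
        for hook in hooks:
            hook.after_run(context, None)
        stop_requested = context.stop_requested  # a hook may have asked to stop
    for hook in hooks:
        hook.end(session)


log = []
simulate([RecordingHook(log)], num_batches=2)
print(" -> ".join(log))
```

Note the last before_run with no matching after_run: the final run call raises, the loop breaks, and end() follows directly.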
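So the order is begin -> after_create_session -> before_run -> after_run (the last two repeated for every step) -> end. Hope this helps.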