 

What is the sequence of SessionRunHook's member function to be called? [closed]

Tags:

tensorflow

After reading the API docs, I still can't understand how SessionRunHook is used. For example, in what order are SessionRunHook's member functions called? Is it after_create_session -> before_run -> begin -> after_run -> end? I also can't find a tutorial with detailed examples. Is there a more detailed explanation?

asked Aug 06 '17 13:08 by gaussclb



1 Answer

You can find a tutorial here. It is a little long, but you can skip the part about building the network. Alternatively, read my short summary below, based on my experience.

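First, hooks are run by a MonitoredSession (for example one created with tf.train.MonitoredTrainingSession), so use that instead of a plain tf.Session; a plain Session does not call hooks.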

A SessionRunHook extends session.run() calls for the MonitoredSession.
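To make "extends session.run()" concrete, here is a minimal pure-Python sketch, assuming a toy session that looks up named fetches in a dict. ToySession, HookedSession and LossWatcher are hypothetical names, not TensorFlow classes, and the signatures are trimmed: in real TensorFlow, before_run(run_context) returns a tf.train.SessionRunArgs and after_run(run_context, run_values) reads run_values.results, as the _LoggerHook example in this answer does. The point it shows is that each hook's extra fetches are merged into the caller's run() call, and the hook gets its own results back while the caller sees only what it asked for.

```python
# Toy model of a MonitoredSession-style wrapper extending run() with hooks.
# Pure Python: all class names are hypothetical stand-ins, not TensorFlow APIs.


class ToySession:
    """Stand-in for tf.Session: 'evaluates' fetches by looking up names."""

    def __init__(self, values):
        self._values = values

    def run(self, fetches):
        return [self._values[name] for name in fetches]


class HookedSession:
    """Merges every hook's extra fetches into the caller's run() call."""

    def __init__(self, sess, hooks):
        self._sess = sess
        self._hooks = hooks

    def run(self, fetches):
        # before_run: each hook may request additional fetches.
        extra = [hook.before_run() for hook in self._hooks]
        merged = list(fetches) + [name for names in extra for name in names]
        results = self._sess.run(merged)
        # after_run: hand each hook back only the values it requested.
        offset = len(fetches)
        for hook, names in zip(self._hooks, extra):
            hook.after_run(results[offset:offset + len(names)])
            offset += len(names)
        # The caller only sees the values it asked for.
        return results[:len(fetches)]


class LossWatcher:
    """Hypothetical hook that asks for 'loss' on every run() call."""

    def __init__(self):
        self.seen = []

    def before_run(self):
        return ["loss"]

    def after_run(self, values):
        self.seen.append(values[0])


watcher = LossWatcher()
sess = HookedSession(ToySession({"train_op": None, "loss": 0.25}), [watcher])
print(sess.run(["train_op"]))  # [None]: the caller gets only its own fetch
print(watcher.seen)            # [0.25]: the hook received the loss value
```

This is why the training loop itself never has to fetch loss just for logging: the hook adds it to the same run() call.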

Next, some common ready-made SessionRunHook classes are listed here. A simple one is LoggingTensorHook, but you will probably need to add the following line after your imports to see its output when running:

tf.logging.set_verbosity(tf.logging.INFO)

Alternatively, you can implement your own SessionRunHook class. Here is a simple one from the CIFAR-10 tutorial:

import time
from datetime import datetime

import tensorflow as tf

# `loss` and `FLAGS` are defined in the surrounding cifar10 training script.
class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    # Called once, before the session is created.
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    # Called before every run() call; the returned fetches are added to it.
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    # Called after every run() call; run_values.results holds the loss.
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print(format_str % (datetime.now(), self._step, loss_value,
                          examples_per_sec, sec_per_batch))

Here loss (like FLAGS) is defined outside the class, in the training script. This _LoggerHook writes its output with print, while LoggingTensorHook logs through tf.logging at the INFO level.

Finally, to better understand how it all works, here is the execution order of a MonitoredSession in pseudocode:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()
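So the order is begin -> after_create_session -> (before_run -> after_run) repeated once per step -> end, not the order guessed in the question. Note that when sess.run() raises OutOfRangeError or StopIteration, after_run is skipped for that step but end is still called. A minimal pure-Python simulation of that pseudocode, assuming a fake session that runs out of input after a fixed number of batches, makes the order visible. FakeSession, RecordingHook and run_monitored are hypothetical helpers, the callback signatures are trimmed (the real ones also take run_context, run_values or coord), and a real MonitoredSession additionally stops when a hook calls run_context.request_stop().

```python
# Pure-Python simulation of the hook call order in the pseudocode above.
# All names are hypothetical; method names mirror tf.train.SessionRunHook.


class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError (input exhausted)."""


class FakeSession:
    """Pretends to have an input pipeline with `num_batches` batches."""

    def __init__(self, num_batches):
        self._remaining = num_batches

    def run(self):
        if self._remaining == 0:
            raise OutOfRangeError()
        self._remaining -= 1


class RecordingHook:
    """Records each callback name so the order can be inspected."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def after_create_session(self, session):
        self.calls.append("after_create_session")

    def before_run(self):
        self.calls.append("before_run")

    def after_run(self):
        self.calls.append("after_run")

    def end(self, session):
        self.calls.append("end")


def run_monitored(hooks, num_batches):
    for hook in hooks:
        hook.begin()
    sess = FakeSession(num_batches)
    for hook in hooks:
        hook.after_create_session(sess)
    while True:
        for hook in hooks:
            hook.before_run()
        try:
            sess.run()
        except (OutOfRangeError, StopIteration):
            break  # after_run is skipped for the failed step
        for hook in hooks:
            hook.after_run()
    for hook in hooks:
        hook.end(sess)  # still called after the input runs out


hook = RecordingHook()
run_monitored([hook], num_batches=2)
print(hook.calls)
# ['begin', 'after_create_session', 'before_run', 'after_run',
#  'before_run', 'after_run', 'before_run', 'end']
```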

Hope this helps.

answered Oct 08 '22 04:10 by LI Xuhong
