Creating `input_fn` from iterator

Most tutorials focus on the case where the entire training dataset fits into memory. However, I have an iterator which acts as an infinite stream of (features, labels)-tuples (creating them cheaply on the fly).

When implementing the input_fn for a TensorFlow Estimator, can I return a single item from the iterator, as in

import tensorflow as tf

def input_fn():
    # `it` is my infinite iterator of (features, labels) batches
    feature_batch, label_batch = next(it)
    return tf.constant(feature_batch), tf.constant(label_batch)

or does input_fn have to return the same (features, labels) tuples on each call?

Moreover, is this function called multiple times during training, as I hope, like in the following pseudocode?

for i in range(max_iter):
   learn_op(input_fn())
Manuel Schmidt asked May 03 '17 13:05


2 Answers

The tensors returned by input_fn are used throughout training, but the function itself is called only once. In particular, wrapping next(it) in tf.constant would feed the same batch at every step. So creating a sophisticated input_fn that goes beyond returning a constant array, as explained in the tutorial, is not as straightforward.
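
To make the "called once, evaluated every step" behaviour concrete, here is a toy, pure-Python model (no TensorFlow involved; `fake_train`, `make_constant_input_fn`, `make_streaming_input_fn` and `ToyTensor` are hypothetical names). A "tensor" is modelled as a zero-argument callable that is re-evaluated at each training step, and `fake_train` calls `input_fn` exactly once, assuming this mirrors what the Estimator does. Freezing `next(it)` into a constant yields one batch forever, whereas a tensor that pulls from the iterator on each evaluation sees fresh data.

```python
import itertools
from typing import Callable, Iterator, List, Tuple

Batch = Tuple[int, int]
ToyTensor = Callable[[], Batch]  # toy stand-in for a tensor: re-evaluated at every step


def fake_train(input_fn: Callable[[], ToyTensor], steps: int) -> List[Batch]:
    """Toy model of Estimator.train(): build the input once, then run `steps` steps."""
    batch_tensor = input_fn()  # input_fn is called exactly once
    return [batch_tensor() for _ in range(steps)]  # its output is evaluated every step


def make_stream() -> Iterator[Batch]:
    """Infinite stream of (features, labels) pairs, created cheaply on the fly."""
    return ((i, 2 * i) for i in itertools.count())


def make_constant_input_fn(it: Iterator[Batch]) -> Callable[[], ToyTensor]:
    def input_fn() -> ToyTensor:
        batch = next(it)  # consumed once, while the input pipeline is built
        return lambda: batch  # like tf.constant: the value is frozen
    return input_fn


def make_streaming_input_fn(it: Iterator[Batch]) -> Callable[[], ToyTensor]:
    def input_fn() -> ToyTensor:
        return lambda: next(it)  # pulls a new batch every time it is evaluated
    return input_fn


print(fake_train(make_constant_input_fn(make_stream()), 3))   # [(0, 0), (0, 0), (0, 0)]
print(fake_train(make_streaming_input_fn(make_stream()), 3))  # [(0, 0), (1, 2), (2, 4)]
```

In real TensorFlow, the second behaviour is what a queue-based input function provides, or, in later versions, `tf.data.Dataset.from_generator`.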

TensorFlow provides two examples of such non-trivial input_fn, for NumPy arrays and pandas DataFrames, but they start from data already in memory, so they do not help with your problem.

You could also look at their code by following the links above to see how they implement an efficient non-trivial input_fn, but you may find that it requires more code than you would like.

If you are willing to use TensorFlow's lower-level interface, things are IMHO simpler and more flexible. There is a tutorial that covers most needs, and the proposed solutions are easier to implement.

In particular, if you already have an iterator that returns data as you described in your question, using placeholders (section "Feeding" in the previous link) should be straightforward.
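
The feeding pattern boils down to the question's pseudocode: the training loop itself calls `next(it)` and hands each batch to the training op. In TF 1.x that call would be `sess.run(train_op, feed_dict={x: features, y: labels})` with `x` and `y` created by `tf.placeholder`. To keep the sketch self-contained and runnable without TensorFlow, the training op is replaced here by a plain-Python gradient-descent step on a one-parameter linear model; `batch_stream`, `train_step` and `train` are illustrative names, not TensorFlow APIs.

```python
import random
from typing import Iterator, List, Tuple

Batch = Tuple[List[float], List[float]]


def batch_stream(batch_size: int, true_w: float = 3.0, seed: int = 0) -> Iterator[Batch]:
    """Hypothetical infinite iterator of (features, labels) batches, labels = true_w * x."""
    rng = random.Random(seed)
    while True:
        xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
        yield xs, [true_w * x for x in xs]


def train_step(w: float, features: List[float], labels: List[float], lr: float) -> float:
    """One gradient-descent step on the mean squared error of the model y = w * x.

    Stands in for sess.run(train_op, feed_dict={x: features, y: labels}).
    """
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(features, labels)) / len(features)
    return w - lr * grad


def train(it: Iterator[Batch], max_iter: int, lr: float = 0.5) -> float:
    w = 0.0
    for _ in range(max_iter):
        features, labels = next(it)  # a fresh batch is pulled from the stream every step
        w = train_step(w, features, labels, lr)
    return w


w = train(batch_stream(batch_size=32), max_iter=200)
print(round(w, 3))  # 3.0
```

Because the loop, not a one-off input function, pulls from the iterator, every step sees new data and the stream never has to fit in memory.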

P-Gn answered Nov 08 '22 01:11


I found a pull request which converts a generator to an input_fn: https://github.com/tensorflow/tensorflow/pull/7045/files

The relevant part is:

  def _generator_input_fn():
    """generator input function."""
    queue = feeding_functions.enqueue_data(
      x,
      queue_capacity,
      shuffle=shuffle,
      num_threads=num_threads,
      enqueue_size=batch_size,
      num_epochs=num_epochs)

    features = (queue.dequeue_many(batch_size) if num_epochs is None
                else queue.dequeue_up_to(batch_size))
    if not isinstance(features, list):
      features = [features]
    features = dict(zip(input_keys, features))
    if target_key is not None:
      if len(target_key) > 1:
        target = {key: features.pop(key) for key in target_key}
      else:
        target = features.pop(target_key[0])
      return features, target
    return features
  return _generator_input_fn
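
Roughly, the snippet hands the generator to a background enqueueing thread, dequeues `batch_size` records per step, turns them into a dict keyed by `input_keys`, and pops the `target_key` entries out as labels. The following plain-Python analogue sketches that flow using the standard-library `queue` and `threading` modules, not TensorFlow's `feeding_functions`. It assumes, as the `input_keys`/`target_key` handling suggests, that the generator yields dicts; all function names are hypothetical, and shuffling, `num_epochs` and multiple threads are left out.

```python
import itertools
import queue
import threading
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple

Record = Dict[str, Any]


def start_enqueue_thread(gen_fn: Callable[[], Iterator[Record]], capacity: int) -> "queue.Queue[Record]":
    """Background producer: pushes generator records into a bounded FIFO queue."""
    q: "queue.Queue[Record]" = queue.Queue(maxsize=capacity)

    def worker() -> None:
        for record in gen_fn():
            q.put(record)  # blocks while the queue is full

    threading.Thread(target=worker, daemon=True).start()
    return q


def dequeue_many(q: "queue.Queue[Record]", batch_size: int) -> List[Record]:
    """Blocks until exactly batch_size records have been taken (cf. dequeue_many)."""
    return [q.get() for _ in range(batch_size)]


def split_batch(records: List[Record], target_key: Optional[Sequence[str]]) -> Tuple[Dict[str, List[Any]], Any]:
    """Turns a list of records into per-key columns, then pops the target column(s)."""
    features = {k: [r[k] for r in records] for k in sorted(records[0])}
    if target_key is None:
        return features, None  # the PR returns only `features` in this case
    if len(target_key) > 1:
        target = {k: features.pop(k) for k in target_key}
    else:
        target = features.pop(target_key[0])
    return features, target


def records() -> Iterator[Record]:
    """Hypothetical infinite generator of dict records."""
    for i in itertools.count():
        yield {"x": i, "y": 2 * i}


q = start_enqueue_thread(records, capacity=8)
features, target = split_batch(dequeue_many(q, 4), target_key=["y"])
print(features, target)  # {'x': [0, 1, 2, 3]} [0, 2, 4, 6]
```

The bounded queue is what lets an infinite generator feed training: the producer only runs ahead by `capacity` records, and each step dequeues a fresh batch.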

Manuel Schmidt answered Nov 08 '22 03:11