Most tutorials focus on the case where the entire training dataset fits into memory. However, I have an iterator which acts as an infinite stream of (features, labels)-tuples (creating them cheaply on the fly).
When implementing the input_fn for TensorFlow's Estimator, can I return a batch drawn from the iterator, as in

    def input_fn():
        feature_batch, label_batch = next(it)
        return tf.constant(feature_batch), tf.constant(label_batch)

or does input_fn have to return the same (features, labels) tuples on each call?
Moreover, is this function called multiple times during training, as I hope it is, like in the following pseudocode:

    for i in range(max_iter):
        learn_op(input_fn())
The arguments of input_fn are used throughout training, but the function itself is called only once. So creating a sophisticated input_fn that goes beyond returning a constant array, as explained in the tutorial, is not as straightforward.
TensorFlow provides two examples of such non-trivial input_fn implementations, for NumPy and pandas arrays, but they start from an array in memory, so this does not help with your problem. You could also have a look at their code by following the links above to see how they implement an efficient non-trivial input_fn, but you may find that it requires more code than you would like.
If you are willing to use the lower-level interface of TensorFlow, things are IMHO simpler and more flexible. There is a tutorial that covers most needs, and the proposed solutions are easy(-er) to implement.
In particular, if you already have an iterator that returns data as you described in your question, using placeholders (section "Feeding" in the previous link) should be straightforward.
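To make the feeding approach concrete, here is a minimal sketch of the plain-Python side: drawing a batch from an infinite (features, labels) iterator and splitting it into the two parallel lists you would pass to session.run via feed_dict. The iterator and helper names (make_example_iterator, next_batch) are hypothetical, not from any TensorFlow API.

```python
from itertools import islice

def make_example_iterator():
    # Hypothetical infinite stream of (features, labels) pairs,
    # generated cheaply on the fly as in the question.
    i = 0
    while True:
        yield ([float(i), float(i + 1)], i % 2)
        i += 1

def next_batch(it, batch_size):
    """Draw batch_size (features, labels) pairs from an infinite
    iterator and split them into two parallel lists, ready for a
    feed_dict such as {x_ph: features, y_ph: labels}."""
    pairs = list(islice(it, batch_size))
    features = [f for f, _ in pairs]
    labels = [l for _, l in pairs]
    return features, labels

it = make_example_iterator()
features, labels = next_batch(it, batch_size=4)
# features -> [[0.0, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]
# labels   -> [0, 1, 0, 1]
```

In the training loop you would call next_batch once per step and feed the result into placeholders; the graph itself is built only once.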
I found a pull request which converts a generator to an input_fn:
https://github.com/tensorflow/tensorflow/pull/7045/files
The relevant part is
    def _generator_input_fn():
      """generator input function."""
      queue = feeding_functions.enqueue_data(
          x,
          queue_capacity,
          shuffle=shuffle,
          num_threads=num_threads,
          enqueue_size=batch_size,
          num_epochs=num_epochs)
      features = (queue.dequeue_many(batch_size) if num_epochs is None
                  else queue.dequeue_up_to(batch_size))
      if not isinstance(features, list):
        features = [features]
      features = dict(zip(input_keys, features))
      if target_key is not None:
        if len(target_key) > 1:
          target = {key: features.pop(key) for key in target_key}
        else:
          target = features.pop(target_key[0])
        return features, target
      return features

    return _generator_input_fn
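The target-splitting logic at the end of that snippet is easy to try out in isolation. Here is a minimal standalone sketch of it (the helper name split_features_target is hypothetical; the body mirrors the pop logic from the pull request):

```python
def split_features_target(features, target_key):
    """Mirror the target-splitting step of _generator_input_fn:
    pop the target column(s) out of the features dict and return
    (features, target), or just features if there is no target."""
    if target_key is not None:
        if len(target_key) > 1:
            # Multiple target columns: keep them as a dict.
            target = {key: features.pop(key) for key in target_key}
        else:
            # Single target column: return its values directly.
            target = features.pop(target_key[0])
        return features, target
    return features

features = {"x1": [1.0], "x2": [2.0], "y": [0]}
out_features, target = split_features_target(features, ["y"])
# out_features -> {"x1": [1.0], "x2": [2.0]}
# target       -> [0]
```

This shows why, after input_fn runs, the Estimator sees the label column only in the target and not among the features.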