I have a Python data generator that produces my training images. I'd like to feed this data into a TensorFlow model using the generator, but I can't figure out how to convert the generator to a TensorFlow tensor. I'm looking for something similar to Keras' fit_generator() function.
Thanks!
The tf.data.Dataset.from_generator() method provides a way to convert Python generators into tf.Tensor objects that evaluate to each successive element from the generator.
Let's say you have a simple generator that generates tuples (but could alternatively generate lists or NumPy arrays):
def g():
    yield 1, 10.0, "foo"
    yield 2, 20.0, "bar"
    yield 3, 30.0, "baz"
You can use the tf.data API to convert the generator first to a tf.data.Dataset, then to a tf.data.Iterator, and finally to a tuple of tf.Tensor objects.
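import tensorflow as tf

# make_one_shot_iterator() is the TensorFlow 1.x API; in 2.x use tf.compat.v1.data.make_one_shot_iterator(dataset).
dataset = tf.data.Dataset.from_generator(g, (tf.int32, tf.float32, tf.string))
iterator = dataset.make_one_shot_iterator()
int_tensor, float_tensor, str_tensor = iterator.get_next()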
You can then use int_tensor, float_tensor, and str_tensor as the inputs to your TensorFlow model. See the tf.data programmer's guide for more ideas.
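Since the question is about training images, here is a rough sketch of the same idea applied to a generator that yields NumPy image arrays, written for TensorFlow 2.x (2.4 or later), where a dataset can be iterated directly and no explicit iterator is needed. The image_generator function, the 8x8x3 image size, and the integer labels are all made up for illustration; substitute your own generator and shapes.

```python
import numpy as np
import tensorflow as tf


def image_generator(num_samples=4, height=8, width=8):
    """Hypothetical stand-in for a real training-image generator."""
    rng = np.random.default_rng(0)
    for label in range(num_samples):
        # Random float32 pixels in [0, 1); a real generator would load images here.
        image = rng.random((height, width, 3), dtype=np.float32)
        yield image, label


# output_signature describes the shape and dtype of each yielded element.
dataset = tf.data.Dataset.from_generator(
    image_generator,
    output_signature=(
        tf.TensorSpec(shape=(8, 8, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

# Group consecutive elements into batches of two (image, label) pairs.
batched = dataset.batch(2)

# In TensorFlow 2.x (eager mode) the dataset can be looped over directly.
batch_info = [
    (images.numpy().shape, labels.numpy().tolist()) for images, labels in batched
]
```

A dataset built this way can also be passed directly to a Keras model's fit() method, which is probably the closest equivalent to fit_generator().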