How to convert a Python data generator to a Tensorflow tensor?

I have a data generator that I am producing training images from. I'd like to feed the data into the Tensorflow model by using this Python data generator, but I can't figure out how to convert the generator to a Tensorflow tensor. I'm looking for something similar to Keras' fit_generator() function.

Thanks!

user135237 asked Dec 24 '22 11:12



1 Answer

The tf.data.Dataset.from_generator() method provides a way to convert Python generators into tf.Tensor objects that evaluate to each successive element from the generator.

Let's say you have a simple generator that generates tuples (but could alternatively generate lists or NumPy arrays):

def g():
  yield 1, 10.0, "foo"
  yield 2, 20.0, "bar"
  yield 3, 30.0, "baz"

You can use the tf.data API to convert the generator first to a tf.data.Dataset, then to a tf.data.Iterator, and finally to a tuple of tf.Tensor objects.

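import tensorflow as tf

dataset = tf.data.Dataset.from_generator(g, (tf.int32, tf.float32, tf.string))

# The Dataset.make_one_shot_iterator() method no longer exists in TF 2.x;
# this compat.v1 function works in late TF 1.x releases and in TF 2.x.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

int_tensor, float_tensor, str_tensor = iterator.get_next()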

You can then use int_tensor, float_tensor, and str_tensor as the inputs to your TensorFlow model. See the tf.data programmer's guide for more ideas.
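In TensorFlow 2.x, eager execution is the default and `fit_generator()` is deprecated in favor of `model.fit()`, which accepts a `tf.data.Dataset` directly. The following is a minimal sketch of that workflow, assuming TF 2.4 or later (needed for the `output_signature` argument and `tf.data.AUTOTUNE`). `image_gen` is a hypothetical stand-in for the asker's image generator that yields random 28x28 grayscale images with integer labels, and the model is an arbitrary placeholder.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the asker's generator: yields (image, label)
# pairs of random 28x28x1 float32 images with labels 0..9.
def image_gen():
    rng = np.random.default_rng(0)
    for i in range(64):
        image = rng.random((28, 28, 1), dtype=np.float32)
        label = i % 10
        yield image, label

# output_signature declares the dtype and static shape of each element,
# so downstream ops (and Keras) know the shapes up front.
dataset = (
    tf.data.Dataset.from_generator(
        image_gen,
        output_signature=(
            tf.TensorSpec(shape=(28, 28, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
    )
    .batch(16)
    .prefetch(tf.data.AUTOTUNE)
)

# In eager mode a Dataset is a Python iterable: each step yields a tuple
# of tf.Tensor objects, with no explicit iterator object required.
images, labels = next(iter(dataset))
print(images.shape, labels.shape)

# Placeholder model; the architecture is arbitrary.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# model.fit consumes the generator-backed Dataset directly.
history = model.fit(dataset, epochs=2, verbose=0)
```

Because `from_generator` calls `image_gen` anew each time the dataset is iterated, every epoch replays the generator from the start.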

mrry answered Dec 28 '22 06:12
