
How to use Keras generator with tf.data API

I am trying to use the image generator from the Keras preprocessing library together with the tf.data API. I want to experiment with this because Keras provides convenient functions for image augmentation. However, I am not sure whether this is actually possible.

Here is how I made a tf dataset from the Keras generator:

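from tensorflow.keras.preprocessing.image import ImageDataGenerator

def make_generator():
    # Rescale pixel values from [0, 255] to [0, 1].
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = train_datagen.flow_from_directory(
        train_dataset_folder, target_size=(224, 224),
        class_mode='categorical', batch_size=32)
    return train_generator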

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)

Note that passing train_generator directly as the argument to tf.data.Dataset.from_generator raises an error. Passing the make_generator function, as above, does not.
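A minimal sketch of why the function is accepted where the object is not: tf.data.Dataset.from_generator requires its generator argument to be callable, and the iterator returned by flow_from_directory is an object rather than a callable. The pure-Python stand-in below (fake_batches is hypothetical and only imitates the Keras iterator) shows the distinction without needing TensorFlow installed.

```python
def fake_batches():
    """Hypothetical stand-in for the iterator flow_from_directory returns."""
    yield "batch-0"
    yield "batch-1"


iterator = fake_batches()       # an iterator object, not a callable
factory = lambda: iterator      # zero-argument callable that returns it

print(callable(iterator))       # False
print(callable(factory))        # True
print(factory() is iterator)    # True
```

Under the same reasoning, wrapping an existing generator object as lambda: train_generator would presumably be accepted in the same way as make_generator.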

When I run it within a session to check the output of the dataset, I get the following error.

iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
    sess.run(next_element)

Found 1000 images belonging to 2 classes.
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1291     try:
-> 1292       return fn(*args)
   1293     except errors.OpError as e:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1276       return self._call_tf_sessionrun(
-> 1277           options, feed_dict, fetch_list, target_list, run_metadata)
   1278

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1366         self._session, options, feed_dict, fetch_list, target_list,
-> 1367         run_metadata)
   1368

InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [32,224,224,3] and element 29 had shape [8,224,224,3].
  [[{{node IteratorGetNext_2}} = IteratorGetNextoutput_shapes=[, ], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

During handling of the above exception, another exception occurred:

Please let me know if anyone has experience with this or knows of an alternative approach.

UPDATE

I was able to solve the problem after using the suggestion by J.E.K.

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))

However, when I pass train_dataset to the Keras fit method, I get the following error.

model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in ()
----> 1 model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1507         steps_name='steps_per_epoch',
   1508         steps=steps_per_epoch,
-> 1509         validation_split=validation_split)
   1510
   1511     # Prepare validation data.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    948           x = self._dataset_iterator_cache[x]
    949         else:
--> 950           iterator = x.make_initializable_iterator()
    951           self._dataset_iterator_cache[x] = iterator
    952           x = iterator

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in make_initializable_iterator(self, shared_name)
    119     with ops.colocate_with(iterator_resource):
    120       initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
--> 121                                                   iterator_resource)
    122     return iterator_ops.Iterator(iterator_resource, initializer,
    123                                  self.output_types, self.output_shapes,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_dataset_ops.py in make_iterator(dataset, iterator, name)
   2542   if _ctx is None or not _ctx._eager_context.is_eager:
   2543     _, _, _op = _op_def_lib._apply_op_helper(
-> 2544         "MakeIterator", dataset=dataset, iterator=iterator, name=name)
   2545     return _op
   2546   _result = None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    348       # Need to flatten all the arguments into a list.
    349       # pylint: disable=protected-access
--> 350       g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    351       # pylint: enable=protected-access
    352     except AssertionError as e:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph)
   5659         graph = graph_element.graph
   5660       elif original_graph_element is not None:
-> 5661         _assert_same_graph(original_graph_element, graph_element)
   5662       elif graph_element.graph is not graph:
   5663         raise ValueError("%s is not from the passed-in graph." % graph_element)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item)
   5595   if original_item.graph is not item.graph:
   5596     raise ValueError("%s must be from the same graph as %s." % (item,
-> 5597                                                                 original_item))
   5598
   5599

ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) must be from the same graph as Tensor("FlatMapDataset:0", shape=(), dtype=variant).

Is this a bug, or is the Keras fit method not meant to be used this way?

siby asked Oct 03 '18 21:10



1 Answer

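I tried to reproduce your result with a simple example and found that the output shapes differ when batching is applied both inside the generator function and in tf.data.

The Keras function train_datagen.flow_from_directory(batch_size=32) already returns the data with shape [batch_size, height, width, channels]. If you then call tf.data.Dataset.batch(32), the data is batched again into shape [batch_size, batch_size, height, width, channels].

This is the likely cause of your error. With 1000 images and a batch size of 32, the last batch the generator yields in each pass holds only 8 images, and tf.data cannot stack that [8,224,224,3] element together with the [32,224,224,3] elements. Dropping the extra .batch(32) call, and letting the Keras generator do the batching, avoids the problem.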
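To make the shape arithmetic concrete, here is a small pure-Python sketch. It does not call TensorFlow or Keras: generator_batch_shapes and rebatch are hypothetical helpers that only imitate, under simplifying assumptions, the shapes a directory iterator yields and the equal-shape requirement of batching.

```python
def generator_batch_shapes(num_images, batch_size, image_shape=(224, 224, 3)):
    """Shapes of the batches a directory iterator yields in one full pass."""
    full, remainder = divmod(num_images, batch_size)
    sizes = [batch_size] * full + ([remainder] if remainder else [])
    return [(n,) + image_shape for n in sizes]


def rebatch(shapes, batch_size):
    """Imitate a second batching step: stacked elements must share one shape."""
    out = []
    for start in range(0, len(shapes), batch_size):
        group = shapes[start:start + batch_size]
        if len(set(group)) != 1:
            raise ValueError(
                "Cannot batch tensors with different shapes: %s" % sorted(set(group))
            )
        out.append((len(group),) + group[0])
    return out


shapes = generator_batch_shapes(1000, 32)
print(len(shapes), shapes[0], shapes[-1])
# 32 (32, 224, 224, 3) (8, 224, 224, 3)

print(rebatch(shapes[:31], 31))
# [(31, 32, 224, 224, 3)]

try:
    rebatch(shapes, 32)
except ValueError as err:
    print(err)
```

The first 31 batches stack into a five-dimensional element, which is the double batching described above, while including the short final batch triggers the shape mismatch.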

J.E.K answered Sep 30 '22 06:09