 

How to Properly Combine TensorFlow's Dataset API and Keras?

Keras' fit_generator() model method expects a generator that produces tuples of the form (inputs, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)

Here's the error I get:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration

Strangely enough, adding a line containing next(datagen) directly after initializing datagen makes the code run without any errors.

Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?

asked Sep 09 '17 22:09 by Jason


1 Answer

Update June 09, 2018

  • Starting with TensorFlow 1.9, you can pass a tf.data.Dataset object (or, as in the example below, a one-shot iterator over it) directly to tf.keras.Model.fit(), and it behaves much like fit_generator().
  • A complete example can be found in this gist.
import tensorflow as tf

# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train, is_training=True)

model = ...  # replace with your compiled tf.keras model
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    epochs=5,
    verbose=1)
  • tfdata_generator is a helper function that returns a batched, repeating tf.data.Dataset (shuffled when is_training is True).
import tensorflow as tf

_NUM_CLASSES = 10  # MNIST has ten digit classes


def tfdata_generator(images, labels, is_training, batch_size=128):
    '''Construct a data generator using `tf.data.Dataset`.'''

    def map_fn(image, label):
        '''Preprocess raw data to trainable input.'''
        x = tf.reshape(tf.cast(image, tf.float32), (28, 28, 1))
        y = tf.one_hot(tf.cast(label, tf.uint8), _NUM_CLASSES)
        return x, y

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        dataset = dataset.shuffle(1000)  # buffer size depends on sample size
    dataset = dataset.map(map_fn)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

    return dataset
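One detail worth noticing in tfdata_generator: batch() comes before repeat(), so each pass over the data ends with a partial batch. With the 60,000 MNIST training images and batch_size=128 that is 468 full batches plus one batch of 96, while steps_per_epoch=len(x_train) // 128 asks for 468 steps per Keras epoch, so Keras epochs and passes over the data drift apart. The pure-Python sketch below only mimics the batch-then-repeat ordering (it does not call tf.data); batch_then_repeat and batch_sizes are hypothetical helper names.

```python
import itertools
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def batch_then_repeat(items: List[T], batch_size: int) -> Iterator[List[T]]:
    """Mimic dataset.batch(batch_size).repeat(): batch one pass, then start over."""
    while True:
        for start in range(0, len(items), batch_size):
            # The last slice of each pass may be shorter than batch_size.
            yield items[start:start + batch_size]


def batch_sizes(num_items: int, batch_size: int, num_batches: int) -> List[int]:
    """Sizes of the first num_batches batches produced by batch_then_repeat."""
    stream = batch_then_repeat(list(range(num_items)), batch_size)
    return [len(batch) for batch in itertools.islice(stream, num_batches)]


if __name__ == "__main__":
    num_train, batch_size = 60000, 128
    steps_per_epoch = num_train // batch_size
    print(steps_per_epoch)  # 468
    sizes = batch_sizes(num_train, batch_size, steps_per_epoch + 2)
    print(sizes[steps_per_epoch - 1:])  # [128, 96, 128]
```

Swapping the order to repeat() then batch() would make every batch full, because batches could then span the boundary between passes.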

Old Solution:

In addition to @Yu-Yang's answer, you can also wrap a tf.data.Dataset in a Python generator for fit_generator() as follows:

import tensorflow as tf
import keras.backend as K
from tensorflow.contrib.learn.python.learn.datasets import mnist

IMAGE_SHAPE = (28, 28, 1)  # load_mnist() returns MNIST images flattened to 784 values
NUM_CLASSES = 10


def tfdata_generator(images, labels, batch_size=128, shuffle=True):
    def map_func(image, label):
        '''A transformation function'''
        x_train = tf.reshape(tf.cast(image, tf.float32), IMAGE_SHAPE)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), NUM_CLASSES)
        return x_train, y_train

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(map_func)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)  # shuffle() requires a buffer size
    dataset = dataset.batch(batch_size).repeat()
    iterator = dataset.make_one_shot_iterator()

    next_batch = iterator.get_next()
    while True:
        yield K.get_session().run(next_batch)


data = mnist.load_mnist()
model = ...  # replace with your compiled Keras model
model.fit_generator(generator=tfdata_generator(data.train.images, data.train.labels),
                    steps_per_epoch=200,
                    workers=0,  # This is important: run the generator on the main thread
                    verbose=1)
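Why workers = 0 matters, and why priming with next(datagen) fixed the code in the question: in TensorFlow 1.x the default graph is a property of the current thread, and fit_generator() by default pulls batches on a separate worker thread (the "Exception in thread Thread-1" in the traceback). A generator's body does not run until its first next(), so the Dataset and iterator ops are created in the default graph of whichever thread advances it first, while the session stays bound to the main thread's graph. The sketch below is a toy model in plain Python, not TensorFlow: threading.local stands in for TF's per-thread default graph, and Graph, Session, get_default_graph, make_generator and advance_on_new_thread are hypothetical stand-ins.

```python
import threading


class Graph:
    """Toy stand-in for tf.Graph: it only records which ops belong to it."""

    def __init__(self):
        self.ops = set()


# Toy per-thread "default graph", mirroring how TF 1.x keeps one per thread.
_local = threading.local()


def get_default_graph():
    # A thread that has never asked before gets a brand-new, empty graph.
    if not hasattr(_local, "graph"):
        _local.graph = Graph()
    return _local.graph


class Session:
    """Toy stand-in for tf.Session: bound to the caller's default graph."""

    def __init__(self):
        self.graph = get_default_graph()

    def run(self, op):
        if op not in self.graph.ops:
            raise ValueError("op is not an element of this graph.")
        return op


def make_generator(sess):
    # Like create_data_generator(): nothing here runs until the first next(),
    # so the op lands in the default graph of whichever thread calls it first.
    op = object()
    get_default_graph().ops.add(op)
    while True:
        yield sess.run(op)


def advance_on_new_thread(gen):
    """Call next(gen) on a fresh thread, as Keras' worker thread does."""
    outcome = {}

    def worker():
        try:
            next(gen)
            outcome["ok"] = True
        except ValueError:
            outcome["ok"] = False

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    return outcome["ok"]


if __name__ == "__main__":
    sess = Session()  # created on the main thread

    # Original code: the first next() happens on the worker thread.
    print(advance_on_new_thread(make_generator(sess)))  # False

    # Priming with next(datagen) on the main thread first.
    datagen = make_generator(sess)
    next(datagen)
    print(advance_on_new_thread(datagen))  # True
```

With workers = 0 every next() runs on the main thread, which is the priming case above, so K.get_session().run(next_batch) finds the iterator's ops in the graph it expects.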
answered Sep 18 '22 13:09 by Dat