Keras' fit_generator() model method expects a generator that produces tuples of the form (inputs, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:
# Minimal reproduction: wrap a tf.contrib.data Dataset iterator in a Python
# generator and hand that generator to Keras' fit_generator().
import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

# Presumably set to quiet TensorFlow's C++ console logging -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        """Yield (inputs, targets) batches evaluated through the enclosing `sess`.

        Being a generator function, none of this body (including the Dataset
        and iterator construction) executes until the first `next()` call on
        the returned generator object.
        """
        # Inputs are the column vector [0, 1, 2, 3]; targets are [5, 6, 7, 8].
        # Both datasets repeat indefinitely.
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()
        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()
        # Pair each input with its target and group the pairs into batches of 4.
        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            # NOTE(review): get_next() is invoked once per yielded batch rather
            # than once before the loop -- confirm this is intended.
            next_val = iterator.get_next()
            # The reported traceback points at this line: sess.run() raises
            # "ValueError: Tensor ... is not an element of this graph" when
            # Keras' data_generator_task thread calls next() on the generator.
            yield sess.run(next_val)

    # Creates the generator object only; the body above has not run yet.
    datagen = create_data_generator()

    # One-feature input feeding a single Dense unit with ReLU activation.
    input_vals = Input(shape=(1,))
    output = Dense(1, activation='relu')(input_vals)
    model = Model(inputs=input_vals, outputs=output)
    model.compile('rmsprop', 'mean_squared_error')
    # The generator never terminates, so steps_per_epoch bounds each epoch.
    model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                        verbose=2, max_queue_size=2)
Here's the error I get:
Using TensorFlow backend. Epoch 1/5 Exception in thread Thread-1: Traceback (most recent call last): File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__ fetch, allow_tensor=True, allow_operation=True)) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked raise ValueError("Tensor %s is not an element of this graph." % obj) ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner self.run() File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run self._target(*self._args, **self._kwargs) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task generator_output = next(self._generator) File "./datagen_test.py", line 25, in create_data_generator yield sess.run(next_val) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run run_metadata_ptr) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__ self._fetch_mapper = _FetchMapper.for_fetch(fetches) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch return _ListFetchMapper(fetch) File 
"/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__ self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp> self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch return _ElementFetchMapper(fetches, contraction_fn) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__ 'Tensor. (%s)' % (fetch, str(e))) ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.) Traceback (most recent call last): File "./datagen_test.py", line 34, in <module> verbose=2, max_queue_size=2) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper return func(*args, **kwargs) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator generator_output = next(output_generator) StopIteration
Strangely enough, adding a line containing next(datagen) directly after the line where I initialize datagen causes the code to run just fine, with no errors.
Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?
A tf.data.Dataset object can be passed directly into keras.Model.fit(), and it will then act similarly to fit_generator
# Example: train a Keras model on MNIST by feeding a tf.data pipeline to fit().

# Single source of truth for the batch size: it is used both to build the
# input pipeline and to derive the number of steps per epoch, so the two
# cannot drift apart.
BATCH_SIZE = 128

# Load mnist training data (the test split is not used here).
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train, is_training=True,
                                batch_size=BATCH_SIZE)

# Placeholder: replace the Ellipsis with your compiled Keras model.
model = ...

# The pipeline's one-shot iterator is passed where fit() would otherwise take
# arrays; steps_per_epoch tells Keras how many batches make up one epoch.
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // BATCH_SIZE,
    epochs=5,
    verbose=1,
)
Here, tfdata_generator is a function that returns an iterable tf.data.Dataset
def tfdata_generator(images, labels, is_training, batch_size=128):
    """Build the input pipeline as a batched, repeating `tf.data.Dataset`.

    Args:
        images: Raw images; each one is cast to float32 and reshaped to
            (28, 28, 1).
        labels: Class labels; each one is one-hot encoded over
            `_NUM_CLASSES` classes.
        is_training: When truthy, the raw examples are shuffled before
            preprocessing.
        batch_size: Number of examples per batch.

    Returns:
        A `tf.data.Dataset` yielding `(image, one_hot_label)` batches that
        repeats and prefetches with `tf.contrib.data.AUTOTUNE`.
    """

    def _to_model_input(raw_image, raw_label):
        """Convert one raw (image, label) pair into trainable tensors."""
        pixels = tf.cast(raw_image, tf.float32)
        image_tensor = tf.reshape(pixels, (28, 28, 1))
        class_index = tf.cast(raw_label, tf.uint8)
        one_hot_label = tf.one_hot(class_index, _NUM_CLASSES)
        return image_tensor, one_hot_label

    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        # Shuffle buffer size; the right value depends on the sample size.
        pipeline = pipeline.shuffle(1000)

    return (
        pipeline.map(_to_model_input)
        .batch(batch_size)
        .repeat()
        .prefetch(tf.contrib.data.AUTOTUNE)
    )
In addition to @Yu-Yang's answer, you can also wrap a tf.data.Dataset in a Python generator and pass it to fit_generator, as follows:
from tensorflow.contrib.learn.python.learn.datasets import mnist


def tfdata_generator(images, labels, batch_size=128, shuffle=True,
                     shuffle_buffer_size=1000):
    """Yield NumPy batches drawn from a `tf.data.Dataset`, for `fit_generator`.

    `image_shape` and `num_classes` are not defined in this snippet; they must
    exist at module scope before the generator is first advanced.

    Args:
        images: Raw images; each one is cast to float32 and reshaped to
            `image_shape`.
        labels: Class labels; each one is one-hot encoded over `num_classes`
            classes.
        batch_size: Number of examples per yielded batch.
        shuffle: Whether to shuffle the preprocessed examples before batching.
        shuffle_buffer_size: Buffer size handed to `Dataset.shuffle()`, which
            requires one; only used when `shuffle` is true.

    Yields:
        The result of running the iterator's next-batch tensors in the Keras
        backend session, one `(images, one_hot_labels)` batch per step.
    """
    def map_func(image, label):
        '''A transformation function'''
        x_train = tf.reshape(tf.cast(image, tf.float32), image_shape)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        # Return a tuple, not a list: tf.data treats a list as one value to be
        # converted into a single tensor, not as two separate components.
        return x_train, y_train

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(map_func)
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size).repeat()

    # Build the iterator and its get_next() tensors once; every loop pass
    # re-runs the same tensors to fetch the following batch.
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    while True:
        yield K.get_session().run(next_batch)


data = mnist.load_mnist()

# Placeholder: replace the Ellipsis with your compiled Keras model.
model = ...

model.fit_generator(
    generator=tfdata_generator(data.train.images, data.train.labels),
    steps_per_epoch=200,
    # This is important: per the Keras docs, workers=0 executes the generator
    # on the main thread instead of in a background worker.
    workers=0,
    verbose=1,
)
If you love us, you can donate via PayPal or buy us a coffee so we can keep maintaining and growing. Thank you!
Donate Us With