I'm trying to create a Dataset object in TensorFlow 1.14 (I have some legacy code that I can't change for this specific project) starting from NumPy arrays, but every time I try, everything gets copied into my graph, so the event log file I create is huge (719 MB in this case).
Originally I tried tf.data.Dataset.from_tensor_slices(), but it didn't work. I then read that this is a common problem, and someone suggested trying generators, so I tried the following code, but again I got a huge event file (719 MB again):
def fetch_batch(x, y, batch):
    i = 0
    while i < batch:
        yield (x[i,:,:,:], y[i])
        i += 1
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
training_dataset = tf.data.Dataset.from_generator(
    fetch_batch,
    args=[images, np.int32(labels), batch_size],
    output_types=(tf.float32, tf.int32),
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))
file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())
I know that in this case I could use the tensorflow_datasets API and it would be easier, but this is a more general question about how to create datasets in general, not only with the MNIST one. Could you explain what I am doing wrong? Thank you.
I guess it's because you are using args in from_generator. The values passed through args are converted to tensors, so the NumPy arrays you provide end up stored in the graph as constants.
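One way to check this guess is to compare the serialized size of a graph built each way. The sketch below is an illustration and not code from the question: it uses a small zero-filled array as a stand-in for the images, a hypothetical helper named graph_size_bytes, and a separate tf.Graph per variant so that the sizes can be compared side by side.

```python
import numpy as np
import tensorflow as tf  # written with the TF 1.14 graph-mode API in mind

# Stand-in data (assumption): 1000 fake 28x28 float32 "images".
data = np.zeros((1000, 28, 28), dtype=np.float32)
item_shape = tf.TensorShape([28, 28])


def graph_size_bytes(graph):
    """Hypothetical helper: size of the serialized GraphDef in bytes."""
    return len(graph.as_graph_def().SerializeToString())


def rows(x):
    # Generator that receives the array through `args`.
    for row in x:
        yield row


# 1) from_tensor_slices: the array is converted to a constant in the graph.
with tf.Graph().as_default() as g_slices:
    tf.data.Dataset.from_tensor_slices(data)

# 2) from_generator with args: the array is passed as a tensor argument.
with tf.Graph().as_default() as g_args:
    tf.data.Dataset.from_generator(
        rows, args=[data], output_types=tf.float32, output_shapes=item_shape)

# 3) from_generator with a closure: the array stays on the Python side.
with tf.Graph().as_default() as g_closure:
    tf.data.Dataset.from_generator(
        lambda: iter(data), output_types=tf.float32, output_shapes=item_shape)

size_slices = graph_size_bytes(g_slices)
size_args = graph_size_bytes(g_args)
size_closure = graph_size_bytes(g_closure)

print("array bytes:          ", data.nbytes)
print("from_tensor_slices:   ", size_slices)
print("from_generator + args:", size_args)
print("from_generator closure:", size_closure)
```

If the guess is right, the first two graphs should be at least as large as the array itself, while the closure-based graph should stay small, and the event file written by FileWriter should shrink accordingly.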
What you could do is define a function that will return a generator that will iterate through your set, something like (haven't tested):
def data_generator(images, labels):
    def fetch_examples():
        i = 0
        while True:
            example = (images[i], labels[i])
            i += 1
            i %= len(labels)
            yield example
    return fetch_examples
In your example, this would give:
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
training_dataset = tf.data.Dataset.from_generator(
    data_generator(images, labels),
    output_types=(tf.float32, tf.int32),
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape))).batch(batch_size)
file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())
Note that I changed fetch_batch to fetch_examples, since you probably want to batch using the dataset utilities (.batch).
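Since the snippet above was not tested, here is a small check of the generator factory on its own, without TensorFlow. It is only a sketch: the tiny arrays and the names make_iter and first_five are made up for illustration, and the loop body is a slightly condensed version of fetch_examples.

```python
import itertools

import numpy as np


def data_generator(images, labels):
    """Return a zero-argument callable that yields (image, label) forever."""
    def fetch_examples():
        i = 0
        while True:
            yield images[i], labels[i]
            i = (i + 1) % len(labels)  # wrap around at the end of the data
    return fetch_examples


# Toy data (assumption): three 2x2 "images" with labels 7, 8, 9.
images = np.arange(3 * 2 * 2, dtype=np.float32).reshape(3, 2, 2)
labels = np.array([7, 8, 9], dtype=np.int32)

make_iter = data_generator(images, labels)

# Take five examples to see the wrap-around after the third one.
first_five = [int(label) for _, label in itertools.islice(make_iter(), 5)]
print(first_five)  # [7, 8, 9, 7, 8]
```

Because of the while True loop, the resulting dataset never ends, so the training loop has to decide how many steps make up an epoch. If you would rather have one pass over the data, replacing the loop with a plain for loop over range(len(labels)) should work too.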