 

How to load numpy array in a tensorflow dataset

I'm trying to create a Dataset object in TensorFlow 1.14 (I have some legacy code that I can't change for this specific project) starting from NumPy arrays, but every time I try, everything gets copied into my graph, and for this reason when I create an event log file it is huge (719 MB in this case).

Originally I tried using the function tf.data.Dataset.from_tensor_slices(), but it didn't work. Then I read that this is a common problem and someone suggested trying generators instead, so I tried the following code, but again I got a huge event file (719 MB again):

def fetch_batch(x, y, batch):
    i = 0
    while i < batch:
        yield (x[i,:,:,:], y[i])
        i += 1

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(fetch_batch, 
    args=[images, np.int32(labels), batch_size], output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())

I know in this case I could use the tensorflow_datasets API and it would be easier, but this is a more general question about how to create datasets in general, not only using the MNIST one. Could you explain to me what I am doing wrong? Thank you

Stefano asked Nov 24 '19 16:11


People also ask

Can I use NumPy in TensorFlow?

TensorFlow implements a subset of the NumPy API, available as tf.experimental.numpy. This allows running NumPy code, accelerated by TensorFlow, while also allowing access to all of TensorFlow's APIs.


1 Answer

I suspect it's because you are using args in from_generator. The provided args are converted to tensors, which embeds the NumPy arrays in the graph as constants.

What you can do instead is define a function that returns a generator that iterates through your set, something like this (haven't tested):

def data_generator(images, labels):
  def fetch_examples():
    i = 0
    while True:
      example = (images[i], labels[i])
      i += 1
      i %= len(labels)  # wrap around so the generator never runs out
      yield example
  return fetch_examples

This would give in your example:

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(data_generator(images, labels), output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape))).batch(batch_size)

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())

Note that I changed fetch_batch to fetch_examples since you probably want to batch using the dataset utilities (.batch).
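The key point is that the factory function keeps the arrays in a Python closure rather than passing them through args, so nothing is serialized into the graph. The pattern itself is independent of TensorFlow and can be sketched in plain Python; the toy data below is made up purely for illustration:

```python
# Toy illustration of the closure-based generator pattern (no TensorFlow needed).
# The data lives in the closure, not in any graph; the generator cycles forever.
def data_generator(images, labels):
    def fetch_examples():
        i = 0
        while True:
            example = (images[i], labels[i])
            i += 1
            i %= len(labels)  # wrap around instead of raising StopIteration
            yield example
    return fetch_examples

# Hypothetical toy dataset: three "images" paired with three labels.
images = ["img0", "img1", "img2"]
labels = [0, 1, 2]

gen = data_generator(images, labels)()   # call the factory, then the generator
first_four = [next(gen) for _ in range(4)]
print(first_four)  # the fourth draw cycles back to the first example
```

Because fetch_examples takes no arguments, it can be handed to from_generator directly, and TensorFlow only ever sees the yielded examples, never the full arrays.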

Zaccharie Ramzi answered Oct 13 '22 07:10