
How do you send arguments to a generator function using tf.data.Dataset.from_generator()?

I would like to create several tf.data.Dataset objects using the from_generator() function, passing an argument to the generator function (raw_data_gen) so that it yields different data depending on that argument. That way raw_data_gen could provide training, validation, or test data.

training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))

The error message I get when I try to call from_generator() in this way is:

TypeError: from_generator() got an unexpected keyword argument 'args'

Here is the raw_data_gen function, although you may not need it: my hunch is that the problem is with the call to from_generator():

def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()

michael_question_answerer asked Sep 21 '18 11:09




1 Answer

You need to define a new function based on raw_data_gen that doesn't take any arguments. You can use the lambda keyword to do this.

training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...

Now, we are passing a function to from_generator that doesn't take any arguments, but that will simply act as raw_data_gen with the argument set to 1. You can use the same scheme for the validation and test sets, passing 2 and 3 respectively.
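The wrapping idea can be tried without TensorFlow installed. The sketch below is only an illustration: the SPLITS dictionary and its values are hypothetical stand-ins for the question's filepath/label dictionaries, and functools.partial is shown as an alternative to lambda that the answer itself does not mention.

```python
from functools import partial

# Hypothetical in-memory stand-ins for the question's filepath/label dicts.
SPLITS = {
    1: [([0.1, 0.2], 0), ([0.3], 1)],  # training
    2: [([0.4, 0.5], 1)],              # validation
    3: [([0.6], 0)],                   # test
}


def raw_data_gen(train_val_or_test):
    """Yield (samples, label) pairs for the requested split."""
    if train_val_or_test not in SPLITS:
        raise ValueError("expected 1, 2 or 3, got %r" % (train_val_or_test,))
    for samples, label in SPLITS[train_val_or_test]:
        yield samples, label


# Two equivalent ways to build a callable that takes no arguments.
training_gen = lambda: raw_data_gen(train_val_or_test=1)
validation_gen = partial(raw_data_gen, 2)

# Each call returns a fresh generator, so the data can be read repeatedly.
first_pass = list(training_gen())
second_pass = list(training_gen())
validation_items = list(validation_gen())

# The zero-argument callable then takes the place of raw_data_gen, e.g.
# tf.data.Dataset.from_generator(training_gen, (tf.float32, tf.uint8), ...)
```

If the wrappers are created inside a loop, partial is the safer choice: it binds the argument value immediately, whereas a lambda looks up the loop variable only when it is called. Newer TensorFlow releases also document an args parameter for from_generator, so upgrading may make the original call work; the wrapper approach does not depend on that.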

xdurch0 answered Oct 12 '22 13:10