
TensorFlow tf.data.Dataset and bucketing

For an LSTM network, I've seen great improvements with bucketing.

I've come across the bucketing section in the TensorFlow docs (tf.contrib).

Though in my network I am using the tf.data.Dataset API. Specifically, I'm working with TFRecords, so my input pipeline looks something like this:

dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})

How can I incorporate the bucketing method into the tf.data.Dataset pipeline?

If it matters, in every record in the TFRecords file I have the sequence length saved as an integer.
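A _parse_function for such records might look roughly like this (a sketch only: the feature keys below are placeholders, and whether the real function returns a tuple or a dict isn't shown in the question):

import tensorflow as tf

# Hypothetical parse function: the feature keys 'sequence' and 'length'
# are assumptions, not taken from the question.
def _parse_function(serialized_example):
    features = tf.parse_single_example(serialized_example, features={
        'sequence': tf.VarLenFeature(tf.int64),
        'length': tf.FixedLenFeature([], tf.int64),
    })
    sequence = tf.sparse_tensor_to_dense(features['sequence'])
    return sequence, features['length']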

bluesummers asked May 30 '18
People also ask

What does tf.data.Dataset do?

The tf.data API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training.
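As a rough sketch (the filename and parse function below are placeholders, not taken from this question):

import tensorflow as tf

# Illustrative only: the filename and decode_example are placeholders.
def decode_example(record):
    # ... parse the serialized example into model inputs ...
    return record

dataset = tf.data.TFRecordDataset(['train-00.tfrecord'])  # read from files
dataset = dataset.map(decode_example)                      # per-element transformation
dataset = dataset.shuffle(buffer_size=1000)                # randomize element order
dataset = dataset.batch(32)                                # merge elements into batches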

What does tf.data.Dataset.from_tensor_slices do?

from_tensor_slices() removes the first dimension and uses it as the dataset dimension.
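For example, a small sketch with a NumPy array:

import numpy as np
import tensorflow as tf

# A (4, 2) array becomes a dataset of 4 elements of shape (2,): the first
# dimension is consumed as the dataset dimension.
features = np.arange(8).reshape(4, 2)
dataset = tf.data.Dataset.from_tensor_slices(features)
print(dataset.output_shapes)  # (2,)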

Is tf.data.Dataset a generator?

You can use tf.data.Dataset objects as generators for training a machine learning model in TensorFlow, with parallelized processing. The tf.data pipeline is now the gold standard for building an efficient data pipeline for machine learning applications with TensorFlow.
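A minimal sketch of wrapping a plain Python generator (the values are arbitrary; the answer below shows a fuller example):

import tensorflow as tf

def gen():
    # Illustrative generator yielding fixed-length elements.
    for i in range(3):
        yield [i, i + 1]

dataset = tf.data.Dataset.from_generator(gen,
                                         output_types=tf.int32,
                                         output_shapes=[2])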

Which are the three main methods of getting data into a TensorFlow program?

Feeding: Python code provides the data when running each step. Reading from files: an input pipeline reads the data from files at the beginning of a TensorFlow graph. Preloaded data: a constant or variable in the TensorFlow graph holds all the data (for small data sets).
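A brief sketch contrasting feeding and preloaded data in TF1-style code (the values are illustrative):

import tensorflow as tf

data = [[1.0, 2.0], [3.0, 4.0]]

# Feeding: Python supplies the values at run time through a placeholder.
x = tf.placeholder(tf.float32, shape=[None, 2])
doubled = x * 2

# Preloaded data: the whole (small) dataset lives in the graph as a constant.
preloaded = tf.constant(data)

with tf.Session() as sess:
    print(sess.run(doubled, feed_dict={x: data}))
    print(sess.run(preloaded))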


1 Answer

Various bucketing use cases using the Dataset API are explained well here.

bucket_by_sequence_length() example:

import tensorflow as tf

def elements_gen():
    # Variable-length sequences paired with integer labels.
    text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
    label = [1, 2, 1, 2]
    for x, y in zip(text, label):
        yield (x, y)

def element_length_fn(x, y):
    # The bucketing transform needs each element's sequence length.
    return tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                         output_shapes=([None], []),
                                         output_types=(tf.int32, tf.int32))

# Group elements into length buckets and pad each batch to its longest sequence.
dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(
    element_length_func=element_length_fn,
    bucket_batch_sizes=[2, 2, 2],
    bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(2):
        print('Get_next:')
        print(sess.run(batch))

Output:

Get_next:
(array([[1, 2, 3, 0, 0],
   [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
   [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))
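Applied to the question's TFRecord pipeline, the same transform could take the place of padded_batch, roughly like this (a sketch only: the bucket boundaries and batch sizes are made up, and it assumes _parse_function returns each sequence together with its stored length):

dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)

# bucket_by_sequence_length replaces padded_batch: it groups elements into
# length buckets and pads each batch only as far as its own bucket requires.
dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(
    element_length_func=lambda seq, length: tf.cast(length, tf.int32),
    bucket_boundaries=[10, 20, 30],         # illustrative boundaries
    bucket_batch_sizes=[batch_size] * 4))   # one batch size per bucket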
vijay m answered Oct 04 '22