How to use dataset.shard in tensorflow?

Recently I have been looking into the dataset API in TensorFlow, and there is a method dataset.shard() which is intended for distributed computation.

This is what's stated in TensorFlow's documentation:

Creates a Dataset that includes only 1/num_shards of this dataset.

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

This method is said to return a portion of the original dataset. If I have two workers, am I supposed to do:

d_0 = d.shard(FLAGS.num_workers, 0)
d_1 = d.shard(FLAGS.num_workers, 1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()

for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
        ......

Because the documentation does not specify how the subsequent calls should be made, I am a bit confused here.

Thanks!

asked Feb 13 '18 by Jiang Wenbo


People also ask

What is shard in Tensorflow?

Sharding is used when you have multiple workers and each worker should process its own subset of the dataset.

How do you split a dataset in TF?

A robust way to split a dataset into two parts is to first deterministically map every item in the dataset to a bucket with, for example, tf.strings.to_hash_bucket_fast. You can then split the dataset in two by filtering on the bucket.
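The idea above can be sketched in plain Python, with no TensorFlow required. Here `hash_bucket` is a hypothetical stand-in for tf.strings.to_hash_bucket_fast: it just needs to map each item to a stable bucket id, so the same item always lands on the same side of the split.

```python
import hashlib

def hash_bucket(item: str, num_buckets: int) -> int:
    # Deterministic stand-in for tf.strings.to_hash_bucket_fast:
    # hash the item and reduce it to a stable bucket id.
    digest = hashlib.md5(item.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets

items = [f"example_{i}" for i in range(100)]

# Split roughly 80/20 by filtering on the bucket id; in tf.data this
# would be two dataset.filter(...) calls on the same hash.
train = [x for x in items if hash_bucket(x, 10) < 8]
test = [x for x in items if hash_bucket(x, 10) >= 8]
```

Because the bucket is computed from the item itself rather than its position, the split stays consistent even if the dataset is re-shuffled or items are added later.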


1 Answer

You should take a look at the tutorial on Distributed TensorFlow first to better understand how it works.

You have multiple workers that each run the same code, with one small difference: each worker has a different FLAGS.worker_index.

When you use tf.data.Dataset.shard, you supply this worker index and the data is split equally between workers.

Here is an example with 3 workers.

dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)

iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

# Suppose you have 3 workers in total: each shard holds 6 / 3 = 2 elements
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))

We will have the output:

  • 0, 3 on worker 0
  • 1, 4 on worker 1
  • 2, 5 on worker 2
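The round-robin behaviour behind those outputs can be sketched in plain Python: shard(num_shards, index) keeps every element whose position in the stream satisfies position % num_shards == index. This is a simplified model of the semantics, not the TensorFlow implementation itself.

```python
def shard(elements, num_shards, index):
    # Mimic tf.data.Dataset.shard: keep only the elements whose
    # position modulo num_shards equals this worker's index.
    return [x for pos, x in enumerate(elements) if pos % num_shards == index]

data = list(range(6))
print(shard(data, 3, 0))  # worker 0 -> [0, 3]
print(shard(data, 3, 1))  # worker 1 -> [1, 4]
print(shard(data, 3, 2))  # worker 2 -> [2, 5]
```

Note that this selection happens by position, which is why the documentation's example applies shard before repeat and shuffle: sharding after shuffling would give each worker a random, potentially overlapping subset instead of a disjoint one.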
answered Oct 12 '22 by Olivier Moindrot