
Efficient way to iterate over tf.data.Dataset

I want to know which is the most efficient way to iterate through a tf.data.Dataset in TensorFlow 2.4.

I am using the typical:

for example in dataset:
    code

However, I have measured the wall time and, since my dataset is huge, the loop takes too long to run. Is there any other option that reduces the computing time?

asked Oct 23 '25 by carlorop

1 Answer

You can use the .map(map_func) transformation, which is an efficient way to apply preprocessing to each element of your dataset. By default the elements are processed sequentially, but if you set the num_parallel_calls argument (for example to tf.data.AUTOTUNE), map_func runs on multiple elements in parallel. [Reference]

Here is an example from the TensorFlow website:

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1) # instead of adding 1 to each sample in a for loop
list(dataset.as_numpy_iterator())      # ==> [ 2, 3, 4, 5, 6 ]
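To enable the parallel execution mentioned above, pass num_parallel_calls; a minimal sketch, using tf.data.AUTOTUNE so TensorFlow picks the level of parallelism:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# run the map function on several elements at once; output order is still preserved
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=tf.data.AUTOTUNE)
print(list(dataset.as_numpy_iterator()))  # ==> [ 2, 3, 4, 5, 6 ]
```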

You can pass a function as well:

def my_map(x): # if the dataset yields (x, y) pairs, define it as "def my_map(x, y)" and "return x, y"
  return x + 1
                                                  
dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(my_map)          # instead of adding 1 to each sample in a for loop
list(dataset.as_numpy_iterator())      # ==> [ 2, 3, 4, 5, 6 ]
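Beyond map, batching and prefetching usually cut iteration wall time further: the loop body runs once per batch instead of once per element, and the next batch is prepared while the current one is being consumed. A sketch combining the three:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(4)                    # loop once per batch, not per element
dataset = dataset.prefetch(tf.data.AUTOTUNE)  # prepare the next batch in the background

for batch in dataset:
    print(batch.numpy())  # three batches: [0 2 4 6], [8 10 12 14], [16 18]
```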
answered Oct 25 '25 by Kaveh