
Tensorflow: How to find the size of a tf.data.Dataset API object

I understand that the Dataset API is a kind of iterator that does not load the entire dataset into memory, which is why it cannot report the size of the Dataset. I am talking in the context of a large corpus of data stored in text files or TFRecord files. These files are generally read using tf.data.TextLineDataset or something similar. It is trivial to find the size of a dataset loaded using tf.data.Dataset.from_tensor_slices.

The reason I am asking the size of the Dataset is the following: Let's say my Dataset size is 1000 elements. Batch size = 50 elements. Then training steps/batches (assuming 1 epoch) = 20. During these 20 steps, I would like to exponentially decay my learning rate from 0.1 to 0.01 as

    tf.train.exponential_decay(
        learning_rate = 0.1,
        global_step = global_step,
        decay_steps = 20,
        decay_rate = 0.1,
    )

In the above code, I would like to set decay_steps = number of steps/batches per epoch = num_elements/batch_size. This can be calculated only if the number of elements in the dataset is known in advance.
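As a rough illustration of the arithmetic above, the sketch below computes the steps per epoch and evaluates the formula that TensorFlow documents for non-staircase exponential decay, learning_rate * decay_rate ** (global_step / decay_steps), in plain Python. The helper names steps_per_epoch and decayed_learning_rate are made up for this example and are not TensorFlow functions.

```python
import math


def steps_per_epoch(num_elements, batch_size):
    """Number of batches in one epoch, counting a final partial batch."""
    return math.ceil(num_elements / batch_size)


def decayed_learning_rate(initial_lr, global_step, decay_steps, decay_rate):
    """Non-staircase exponential decay: lr * rate ** (step / decay_steps)."""
    return initial_lr * decay_rate ** (global_step / decay_steps)


# Numbers from the question: 1000 elements, batch size 50.
decay_steps = steps_per_epoch(num_elements=1000, batch_size=50)  # 20

# Learning rate at the start, middle, and end of the epoch.
for step in (0, 10, decay_steps):
    print(step, decayed_learning_rate(0.1, step, decay_steps, 0.1))
```

With decay_rate = 0.1 and decay_steps equal to the number of steps in the epoch, the rate falls by a factor of ten over exactly one epoch, which is the 0.1 to 0.01 schedule described above (up to floating-point rounding).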

Another reason to know the size in advance is to split the data into train and test sets using the tf.data.Dataset.take() and tf.data.Dataset.skip() methods.
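To make the dependency on the element count concrete, here is a minimal sketch of how the two split sizes could be derived once the count is known. The helper split_sizes and the 20% test fraction are assumptions for illustration; the take() and skip() calls appear only in comments because no dataset is built here.

```python
def split_sizes(num_elements, test_fraction=0.2):
    """Return (n_train, n_test) for a dataset of known size."""
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    n_test = int(num_elements * test_fraction)
    return num_elements - n_test, n_test


n_train, n_test = split_sizes(1000, test_fraction=0.2)

# With a tf.data.Dataset named `dataset` (hypothetical, not constructed here),
# the counts would be used as:
#   train_ds = dataset.take(n_train)
#   test_ds = dataset.skip(n_train)
print(n_train, n_test)
```

If the dataset is shuffled upstream, the shuffle order would need to be the same on every iteration (for example shuffle(..., reshuffle_each_iteration=False)) for the two subsets to stay disjoint.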

PS: I am not looking for brute-force approaches like iterating through the whole dataset and updating a counter to count the number of elements or putting a very large batch size and then finding the size of the resultant dataset, etc.

omsrisagar asked Jun 19 '18 01:06


2 Answers

You can easily get the number of data samples using :


You can get each element like this:

for step, element in enumerate(dataset.as_numpy_iterator()):
    print(step, element)

You can also get the shape of one sample:


If you want to take specific elements, you can use the shard method as well.

Ixtiyor Majidov answered Oct 18 '22 13:10

I realize this question is two years old, but perhaps this answer will be useful.

If you are reading your data with tf.data.TextLineDataset, then a way to get the number of samples could be to count the number of lines in all of the text files you are using.

Consider the following example:

import random
import string
import tensorflow as tf

filenames = ["data0.txt", "data1.txt", "data2.txt"]

# Generate synthetic data.
for filename in filenames:
    with open(filename, "w") as f:
        lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
        print("\n".join(lines), file=f)

dataset = tf.data.TextLineDataset(filenames)

Trying to get the length with len raises a TypeError:

len(dataset)

But one can calculate the number of lines in a file relatively quickly.

# https://stackoverflow.com/q/845058/5666087
def get_n_lines(filepath):
    # Start at -1 so that an empty file yields 0 lines.
    i = -1
    with open(filepath) as f:
        for i, _ in enumerate(f):
            pass
    return i + 1

n_lines = sum(get_n_lines(f) for f in filenames)
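For very large files, one possible variation (an assumption of this sketch, not part of the answer above) is to read the file in binary chunks and count newline bytes, which avoids decoding the text. The function name count_lines_chunked and the 1 MiB default chunk size are hypothetical choices.

```python
def count_lines_chunked(filepath, chunk_size=1 << 20):
    """Count lines by scanning fixed-size binary chunks for newline bytes."""
    n_lines = 0
    last_byte = b"\n"  # treat an empty file as having no unterminated line
    with open(filepath, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            n_lines += chunk.count(b"\n")
            last_byte = chunk[-1:]
    # A final line without a trailing newline still counts as a line.
    if last_byte != b"\n":
        n_lines += 1
    return n_lines
```

It can be swapped in with sum(count_lines_chunked(f) for f in filenames). This assumes one example per newline-separated line, as in the synthetic files above.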

In the above, n_lines is equal to the number of elements found when iterating over the dataset:

for i, _ in enumerate(dataset):
    pass
n_lines == i + 1  # True

jakub answered Oct 18 '22 15:10
