I understand that the Dataset API is a kind of iterator that does not load the entire dataset into memory, which is why it cannot report the size of the Dataset. I am talking about a large corpus of data stored in text files or TFRecord files. These files are generally read using tf.data.TextLineDataset or something similar. It is trivial to find the size of a dataset loaded using tf.data.Dataset.from_tensor_slices.
The reason I am asking for the size of the Dataset is the following: let's say my Dataset has 1000 elements and the batch size is 50 elements. Then the number of training steps/batches (assuming 1 epoch) is 20. During these 20 steps, I would like to exponentially decay my learning rate from 0.1 to 0.01 as
tf.train.exponential_decay(
    learning_rate=0.1,
    global_step=global_step,
    decay_steps=20,
    decay_rate=0.1,
    staircase=False,
    name=None
)
In the above code, I would like to set decay_steps = number of steps/batches per epoch = num_elements/batch_size. This can be calculated only if the number of elements in the dataset is known in advance.
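As a rough sketch of the arithmetic above, the steps per epoch and the resulting learning rates can be worked out in plain Python. The variable and function names are illustrative only, and the formula is the one documented for tf.train.exponential_decay with staircase=False, namely learning_rate * decay_rate ^ (global_step / decay_steps):

```python
import math

num_elements = 1000  # assumed to be known in advance (this is the open question)
batch_size = 50

# Steps (batches) per epoch; ceil accounts for a final partial batch.
decay_steps = math.ceil(num_elements / batch_size)


def decayed_learning_rate(step, initial_lr=0.1, decay_rate=0.1, steps=decay_steps):
    """Illustrative helper: initial_lr * decay_rate ** (step / steps)."""
    return initial_lr * decay_rate ** (step / steps)


lr_start = decayed_learning_rate(0)           # learning rate at step 0
lr_end = decayed_learning_rate(decay_steps)   # learning rate after one epoch
print(decay_steps, lr_start, lr_end)
```

With these numbers the rate starts at 0.1 and reaches roughly 0.01 after the 20th step, which is the behaviour the question is after.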
Another reason to know the size in advance is to split the data into train and test sets using the tf.data.Dataset.take() and tf.data.Dataset.skip() methods.
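A minimal sketch of such a split, assuming TensorFlow 2.x with eager execution and a toy dataset whose size is already known. The 80/20 ratio and the variable names are hypothetical; with a file-backed dataset, num_elements would have to come from somewhere else:

```python
import tensorflow as tf

# Toy dataset with a known size, standing in for the real data.
num_elements = 10
dataset = tf.data.Dataset.range(num_elements)

# Hypothetical 80/20 split, computed with integer arithmetic.
train_size = (num_elements * 8) // 10

train_dataset = dataset.take(train_size)  # first train_size elements
test_dataset = dataset.skip(train_size)   # everything after them

train_values = list(train_dataset.as_numpy_iterator())
test_values = list(test_dataset.as_numpy_iterator())
print(len(train_values), len(test_values))
```

Note that take() and skip() split by position, so this only gives a sensible train/test split if the elements come in a fixed order on every pass.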
PS: I am not looking for brute-force approaches like iterating through the whole dataset and updating a counter to count the number of elements or putting a very large batch size and then finding the size of the resultant dataset, etc.
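If the size of the dataset is known to TensorFlow (for example, a dataset built with tf.data.Dataset.from_tensor_slices), you can get the number of data samples using:
dataset.__len__()
For a dataset whose size cannot be determined, such as one read with tf.data.TextLineDataset, this raises a TypeError.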
You can get each element like this:
for step, element in enumerate(dataset.as_numpy_iterator()):
    print(step, element)
You can also get the shape of one sample:
dataset.element_spec
If you want to take specific elements, you can use the shard method as well.
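Whether the length is available depends on whether TensorFlow can infer the dataset's cardinality. The following sketch assumes TensorFlow 2.3 or later, where Dataset.cardinality() and len() on datasets are available. It contrasts a dataset of known size with one whose size is lost after a filter; the dataset contents are made up for illustration:

```python
import tensorflow as tf

known = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
# After filter(), the number of surviving elements cannot be inferred statically.
unknown = known.filter(lambda x: x > 1)

known_size = len(known)
known_cardinality = int(known.cardinality().numpy())
unknown_cardinality = int(unknown.cardinality().numpy())

# len() raises TypeError when the cardinality is unknown (or infinite).
try:
    len(unknown)
    len_failed = False
except TypeError:
    len_failed = True

print(known_size, unknown_cardinality == tf.data.UNKNOWN_CARDINALITY, len_failed)
```

Checking cardinality() first avoids having to catch the TypeError when a pipeline may or may not have a known size.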
I realize this question is two years old, but perhaps this answer will be useful.
If you are reading your data with tf.data.TextLineDataset, then one way to get the number of samples is to count the number of lines in all of the text files you are using.
Consider the following example:
import random
import string

import tensorflow as tf

filenames = ["data0.txt", "data1.txt", "data2.txt"]

# Generate synthetic data: each file gets between 10 and 100 one-letter lines.
for filename in filenames:
    with open(filename, "w") as f:
        lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
        print("\n".join(lines), file=f)

dataset = tf.data.TextLineDataset(filenames)
Trying to get the length with len raises a TypeError:
len(dataset)
But one can calculate the number of lines in a file relatively quickly.
# https://stackoverflow.com/q/845058/5666087
def get_n_lines(filepath):
    i = -1  # so that an empty file yields 0 lines
    with open(filepath) as f:
        for i, _ in enumerate(f):
            pass
    return i + 1

n_lines = sum(get_n_lines(f) for f in filenames)
In the above, n_lines is equal to the number of elements found when iterating over the dataset with
for i, _ in enumerate(dataset):
    pass
n_lines == i + 1
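Once the line count is known, it can be attached to the dataset so that len works. This is a sketch assuming TensorFlow 2.3 or later, where tf.data.experimental.assert_cardinality is available; the temporary file and its contents are made up for illustration:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical small text file standing in for the real corpus.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "data.txt")
with open(path, "w") as f:
    f.write("a\nb\nc\nd\ne\n")

# Count the lines in plain Python, as in the answer above.
with open(path) as f:
    n_lines = sum(1 for _ in f)

dataset = tf.data.TextLineDataset([path])

# Tell TensorFlow the size that was computed outside of the pipeline.
dataset = dataset.apply(tf.data.experimental.assert_cardinality(n_lines))

dataset_len = len(dataset)
print(n_lines, dataset_len)
```

The asserted count is checked when the dataset is iterated, so a wrong line count surfaces as an error at that point rather than going unnoticed.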