How to get the shape of a tf.data.Dataset?

I know the dataset has output_shapes, but it only shows the following:

data_set: DatasetV1Adapter shapes: {item_id_hist: (?, ?), tags: (?, ?), client_platform: (?,), entrance: (?,), item_id: (?,), lable: (?,), mode: (?,), time: (?,), user_id: (?,)}, types: {item_id_hist: tf.int64, tags: tf.int64, client_platform: tf.string, entrance: tf.string, item_id: tf.int64, lable: tf.int64, mode: tf.int64, time: tf.int64, user_id: tf.int64}

How can I get the total number of records in my dataset?

cao xiangyu asked May 20 '19 09:05


3 Answers

Where the length is known you can call:

tf.data.experimental.cardinality(dataset)

If this fails (it returns tf.data.experimental.UNKNOWN_CARDINALITY when the length cannot be determined), it's important to know that a TensorFlow Dataset is, in general, lazily evaluated. In the general case we may therefore need to iterate over every record before we can find the length of the dataset.

For example, assuming you have eager execution enabled and it's a small 'toy' dataset that fits comfortably in memory, you could just enumerate it into a new list and grab the last index (then add 1 because list indices start at zero):

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1

Of course, this is inefficient at best and, for large datasets, will fail entirely because the list, which holds one index per record, needs to fit into memory. In such circumstances I can't see any alternative other than to iterate through the records, keeping a manual count.
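The manual count mentioned above could be sketched roughly like this. It is only an illustration: count_records is a hypothetical helper name, and the generator is a stand-in for a real dataset. It relies on the fact that, with eager execution, a tf.data.Dataset can be iterated like any other Python iterable, so the dataset itself could be passed in place of the stand-in.

```python
def count_records(records):
    """Count the items of an iterable in one pass, using constant memory."""
    count = 0
    for _ in records:
        count += 1
    return count


# Stand-in for a dataset of unknown length (hypothetical data).
# With eager execution, a tf.data.Dataset could be passed here instead.
stand_in = (record for record in range(1000) if record % 3 == 0)
print(count_records(stand_in))

# Unlike indexing the last element of a list, this also handles an empty input.
print(count_records([]))
```

This still needs a full pass over the data, so it is no faster than the list version; it just avoids building the list.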

Stewart_R answered Oct 19 '22 21:10


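Convert the dataset to a list of NumPy arrays and take its shape:

import tensorflow as tf

# Materialize every element of the dataset in memory as NumPy arrays.
dataset_to_numpy = list(dataset.as_numpy_iterator())
shape = tf.shape(dataset_to_numpy)
print(shape)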

It produces output like this:

tf.Tensor([1080   64   64    3], shape=(4,), dtype=int32)

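The first dimension (1080 here) is the number of elements. The code is simple to write, but it still takes time to iterate over the dataset, and it only works when every element has the same shape. For more info about tf.data.Dataset, check this link.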

Damon Roux answered Oct 19 '22 21:10


As of 4/15/2022, with TF v2.8, you can get the result by using:

dataset.cardinality().numpy()

ref: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cardinality
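One caveat is that cardinality() does not always return a real count: it can return a negative sentinel when the dataset is infinite or when the length cannot be determined without iterating. The sketch below shows one way to check for that. describe_cardinality is a hypothetical helper name, and the three small datasets are made up for illustration.

```python
import tensorflow as tf


def describe_cardinality(dataset):
    """Return the element count, or a label when no count is available."""
    n = int(dataset.cardinality().numpy())
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        return "infinite"
    if n == tf.data.experimental.UNKNOWN_CARDINALITY:
        return "unknown"
    return n


# Illustrative datasets (not from the question).
known = tf.data.Dataset.range(10)
filtered = known.filter(lambda x: x % 2 == 0)  # length depends on the data
repeated = known.repeat()                      # repeats forever

print(describe_cardinality(known))
print(describe_cardinality(filtered))
print(describe_cardinality(repeated))
```

When the result is "unknown", the only option left is to iterate over the dataset and count.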

Vincent Yuan answered Oct 19 '22 20:10