
Tensorflow 2.0 dataset and dataloader

I am a PyTorch user, and I am used to the data.Dataset and data.DataLoader APIs in PyTorch. I am trying to build the same model with TensorFlow 2.0, and I wonder whether there is an API that works similarly to these.

If there is no such API, can anyone tell me how people usually implement the data loading part in TensorFlow? I've used TensorFlow 1, but I have no experience with the Dataset API; I hard-coded the loading before. I hope there is something like overriding __getitem__ with only an index as input.

Thanks much in advance.

piljae.chae asked Oct 22 '19 13:10



2 Answers

When using the tf.data API, you will usually also make use of the map function.

In PyTorch, your __getitem__ call basically fetches an element from the data structure set up in __init__ and transforms it if necessary.

In TF 2.0, you do the same by initializing a Dataset with one of the Dataset.from_... functions (see from_generator, from_tensor_slices, from_tensors); this is essentially the __init__ part of a PyTorch Dataset. Then you call map to apply the element-wise manipulations you would otherwise put in __getitem__.
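As a rough illustration of that mapping, here is a minimal sketch using from_tensor_slices followed by map. The arrays and the preprocess function are made up for the example, and tf.data.AUTOTUNE assumes a reasonably recent TF 2.x (in 2.0 itself it lives under tf.data.experimental.AUTOTUNE).

```python
import numpy as np
import tensorflow as tf

# Toy in-memory data (hypothetical): 8 samples with 3 features each, plus labels.
features = np.arange(24, dtype=np.float32).reshape(8, 3)
labels = np.arange(8, dtype=np.int64)


def preprocess(x, y):
    # Per-element transform: the work __getitem__ usually does in PyTorch.
    return x / 10.0, y


dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))  # roughly "__init__"
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)   # roughly "__getitem__"
    .batch(4)                                                # DataLoader-style batching
    .prefetch(tf.data.AUTOTUNE)                              # overlap loading and training
)

# Traverse the dataset once and collect the batches as NumPy arrays.
batches = [(x.numpy(), y.numpy()) for x, y in dataset]
```

The batch and prefetch calls cover what a PyTorch DataLoader would normally handle; adding .shuffle(buffer_size) before batch gives the equivalent of shuffle=True.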

Tensorflow datasets are pretty much fancy iterators, so by design you don't access their elements using indices, but rather by traversing them.
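If you already have an index-based class and want to keep it, one possible bridge is to wrap it in a generator and pass that to from_generator. The sketch below assumes a hypothetical SquaresSource class written in the PyTorch style, and the output_signature argument, which needs TF 2.4 or later.

```python
import numpy as np
import tensorflow as tf


class SquaresSource:
    """Hypothetical index-based source, shaped like a PyTorch Dataset."""

    def __len__(self):
        return 5

    def __getitem__(self, index):
        # Return one (input, target) pair for the given index.
        return np.float32(index), np.float32(index ** 2)


source = SquaresSource()


def generate():
    # Turn index-based access into plain iteration for tf.data.
    for i in range(len(source)):
        yield source[i]


dataset = tf.data.Dataset.from_generator(
    generate,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
)

# Elements are reached by traversing the dataset, not by indexing into it.
pairs = [(float(x.numpy()), float(y.numpy())) for x, y in dataset]
```

The generator runs in Python, so this route is usually slower than from_tensor_slices plus map, but it lets existing __getitem__ logic be reused unchanged.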

The guide on tf.data is very useful and provides a wide variety of examples.

Mat answered Oct 03 '22 13:10


I am not familiar with PyTorch, but TensorFlow implements the Keras API, which has the Sequence class, described in the docs as:

Base object for fitting to a sequence of data, such as a dataset

https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

Subclasses of this class implement __getitem__, which takes an index. Note that the index selects a whole batch rather than a single sample.
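A minimal sketch of such a subclass is shown below. ArraySequence and its constructor arguments are invented for the example; only tf.keras.utils.Sequence and its __len__ / __getitem__ contract come from Keras.

```python
import math

import numpy as np
import tensorflow as tf


class ArraySequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence that serves batches from in-memory arrays."""

    def __init__(self, x, y, batch_size):
        super().__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, index):
        # `index` is a batch index, so slice out the matching rows.
        start = index * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]


# Toy data: 10 samples with 3 features each.
x = np.zeros((10, 3), dtype=np.float32)
y = np.arange(10, dtype=np.int64)
seq = ArraySequence(x, y, batch_size=4)
```

An instance like seq can be passed directly to model.fit, which then requests batches by index.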

Eirik Moseng answered Oct 03 '22 13:10