I use TensorFlow, but I'm writing documentation for users who will typically work across a variety of deep learning frameworks.
When working with datasets that don't fit on the local filesystem (TB+), I sample data from a remote data store and write the samples locally in TensorFlow's standard tfrecords format.
During the first epoch of training I will have sampled only a few values, so an epoch of local data is very small; I train on it anyway. On epoch 2 I re-examine which data files have been produced by my sampling subprocesses (now more) and train on the expanded set of local data files, repeating the process each epoch. In this way I build up a local cache of samples and can evict older samples as I fill up the local storage. The local sample cache grows at about the time the model needs the variance the most (towards the latter part of training).
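For concreteness, here is a minimal sketch of what one sampler subprocess might look like on the TensorFlow side, writing fetched samples into local tfrecord shards and evicting the oldest shards once the cache exceeds a budget. The fetch_remote_sample() helper, the shard naming, and the eviction policy are illustrative assumptions, not the exact scheme described above.

import glob
import os
import tensorflow as tf

def sampler_worker(local_dir, worker_id, samples_per_shard=1024, max_local_bytes=500 * 2**30):
    shard = 0
    while True:
        path = os.path.join(local_dir, f"worker{worker_id}-{shard:06d}.tfrecord")
        with tf.io.TFRecordWriter(path) as writer:
            for _ in range(samples_per_shard):
                raw = fetch_remote_sample()  # hypothetical: returns one raw sample as bytes
                example = tf.train.Example(features=tf.train.Features(feature={
                    "sample": tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))}))
                writer.write(example.SerializeToString())
        shard += 1
        # Evict the oldest shards once the local cache exceeds its budget.
        files = sorted(glob.glob(os.path.join(local_dir, "*.tfrecord")), key=os.path.getmtime)
        while sum(os.path.getsize(f) for f in files) > max_local_bytes:
            os.remove(files.pop(0))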
In Python/TensorFlow it's crucial that I not deserialize the data in the Python training-loop process, because the Python GIL can't support the data transfer rates (300-600 MB/sec of raw, incompressible scientific data), and GPU performance suffers when the GIL can't service the training loop fast enough.
Writing the samples to a tfrecords file from subprocesses (Python multiprocessing) allows TensorFlow's native TFRecordDataset to do the deserialization outside of Python, so we sidestep the GIL issues and I can saturate a GPU at high I/O data rates.
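And a sketch of the matching reader, rebuilt at the start of each epoch so that newly written shards are picked up; the feature name and dtype are assumptions. Parsing and decoding run inside tf.data's C++ runtime rather than under the Python GIL.

import glob
import tensorflow as tf

def make_epoch_dataset(local_dir="/local_cache", batch_size=32):
    # Rescan the cache directory each epoch; new shards simply show up here.
    files = sorted(glob.glob(f"{local_dir}/*.tfrecord"))
    feature_spec = {"sample": tf.io.FixedLenFeature([], tf.string)}  # assumed schema

    def parse(example_proto):
        parsed = tf.io.parse_single_example(example_proto, feature_spec)
        return tf.io.decode_raw(parsed["sample"], tf.float32)

    ds = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.experimental.AUTOTUNE)
    return (ds.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
              .batch(batch_size)
              .prefetch(tf.data.experimental.AUTOTUNE))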
I would like to know how I would address this issue in PyTorch. I'm writing about the sampling strategy being used and want to provide specific recommendations to users of both TensorFlow and PyTorch, but I don't know the PyTorch preprocessing ecosystem well enough to write with sufficient detail.
Side note: the only purely Python-based solution that could support these data transfer rates may come in Python 3.8 with System V shared memory and multiprocessing, but I haven't tried that yet as support for it isn't quite sufficient (soon it will be). Existing multiprocessing solutions aren't sufficient because they require deserialization in the training-loop process, which locks the GIL during deserialization at high I/O rates.
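For what it's worth, a minimal sketch of that Python 3.8 idea using multiprocessing.shared_memory, where the consumer wraps the producer's bytes with no copy and no pickling; the block name, sample shape, and fetch_remote_sample_array() / use_in_training() helpers are illustrative assumptions.

import numpy as np
from multiprocessing import shared_memory

SAMPLE_SHAPE = (1024, 1024)                     # assumed fixed sample shape
SAMPLE_BYTES = int(np.prod(SAMPLE_SHAPE)) * 4   # float32

def producer():
    shm = shared_memory.SharedMemory(create=True, size=SAMPLE_BYTES, name="sample_block")
    buf = np.ndarray(SAMPLE_SHAPE, dtype=np.float32, buffer=shm.buf)
    buf[:] = fetch_remote_sample_array()  # hypothetical: returns one float32 sample
    shm.close()

def consumer():
    shm = shared_memory.SharedMemory(name="sample_block")
    # Zero-copy view; no deserialization happens in the training-loop process.
    sample = np.ndarray(SAMPLE_SHAPE, dtype=np.float32, buffer=shm.buf)
    use_in_training(sample)               # hypothetical consumer of the sample
    shm.close()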
The most common approach for handling PyTorch training data is to write a custom Dataset class that loads the data into memory and then serve it up in batches using the built-in DataLoader class. This approach is simple, but it requires you to hold all of the training data in memory.
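A minimal sketch of that pattern, with placeholder tensors standing in for real training data:

import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = samples   # everything lives in memory up front
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

dset = InMemoryDataset(torch.randn(1000, 64), torch.zeros(1000))
loader = DataLoader(dset, batch_size=32, shuffle=True)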
There is also a third-party tfrecord library that provides an IterableDataset reader of TFRecord files for PyTorch; currently both uncompressed and gzip-compressed TFRecords are supported.
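A sketch of how such a reader is typically wired up, here following the interface of the third-party tfrecord package on PyPI; the module path, class name, index_path argument, and feature description are assumptions to check against that library's documentation.

import torch
from tfrecord.torch.dataset import TFRecordDataset  # third-party `tfrecord` package

description = {"sample": "byte"}  # assumed feature schema
dataset = TFRecordDataset("/local_cache/worker0-000000.tfrecord",
                          index_path=None,
                          description=description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)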
An iterator is an object representing a stream of data. You can create one by applying the built-in iter() function to an iterable, and then use the built-in next() function to get the next element from the stream.
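As a small illustration with the loader from the snippet above:

it = iter(loader)        # DataLoader objects are iterables
first_batch = next(it)   # pulls the next batch from the stream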
Actually, you can easily deserialize data in a subprocess by using torch.utils.data.DataLoader. By setting the num_workers argument to 1 or a larger value, you can spawn worker subprocesses, each with its own Python interpreter and GIL.
loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
        # loader in the main process does not claim the GIL at this point
A DataLoader requires a torch.utils.data.Dataset to get data from. It may not be a trivial job to implement a proper subclass in your case. If you need to recreate a Dataset instance for every epoch, you can do something like this:
for epoch in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training
or even better
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
for epoch in range(epochs):
    last_batch_idx = (len(dset) - 1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare the next loader in advance to avoid blocking at the epoch boundary
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training
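For completeness, a sketch of what get_new_dataset() could look like under the scheme described in the question. It assumes, purely for illustration, that the sampler subprocesses write one sample per .npy file into a local cache directory; since __getitem__ runs inside the DataLoader workers, the np.load deserialization stays out of the training process's GIL.

import glob
import numpy as np
import torch
from torch.utils.data import Dataset

class LocalCacheDataset(Dataset):
    def __init__(self, local_dir="/local_cache"):
        # Snapshot whatever files exist right now; recreating the dataset
        # each epoch picks up newly sampled files.
        self.files = sorted(glob.glob(f"{local_dir}/*.npy"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Runs in a DataLoader worker subprocess, not the training process.
        return torch.from_numpy(np.load(self.files[idx]))

def get_new_dataset():
    return LocalCacheDataset()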
As a side note, please note that in most cases it's CPU-bound operations that are affected by the GIL, not I/O-bound operations; i.e., threading will do for any purely I/O-heavy operation and you don't even need subprocesses. For more information, please refer to this question and this Wikipedia article.
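A tiny illustration of that point: blocking file reads release the GIL, so plain threads are enough to overlap the waiting (paths here are placeholders).

import glob
from concurrent.futures import ThreadPoolExecutor

def read_file(path):
    with open(path, "rb") as f:
        return f.read()   # the GIL is released while the read blocks

paths = sorted(glob.glob("/local_cache/*.tfrecord"))
with ThreadPoolExecutor(max_workers=8) as pool:
    blobs = list(pool.map(read_file, paths))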