 

Most efficient way to use a large data set for PyTorch?

Perhaps this question has been asked before, but I'm having trouble finding relevant info for my situation.

I'm using PyTorch to create a CNN for regression with image data. I don't have a formal, academic programming background, so many of my approaches are ad hoc and just terribly inefficient. Many times I can go back through my code and clean things up later, because the inefficiency is not so drastic that performance is significantly affected. However, in this case, my method for using the image data takes a long time, uses a lot of memory, and has to be repeated every time I want to test a change in the model.

What I've done is essentially load the image data into numpy arrays, save those arrays in an .npy file, and then, when I want to use said data for the model, import all of the data from that file. I don't think the data set is really THAT large, as it consists of 5,000 three-channel images of size 64x64. Yet my memory usage shoots up to 70%-80% (out of 16 GB) when it is being loaded, and it takes 20-30 seconds to load in every time.
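
Roughly, what I'm doing now looks like this (a simplified sketch; file names and shapes are placeholders):

import numpy as np

# preprocessing: all 5000 images end up in one big array built from the image files
images = np.zeros((5000, 3, 64, 64), dtype=np.float32)  # placeholder instead of the real decode step
np.save('images.npy', images)

# every time I want to train or test the model, the whole thing is loaded back
data = np.load('images.npy')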

My guess is that I'm being dumb about the way I'm loading it in, but frankly I'm not sure what the standard is. Should I stage the image data in some intermediate format before I need it, or should the data be loaded directly from the image files? And in either case, what is the best, most efficient way to do that, independent of file structure?

I would really appreciate any help on this.

asked Dec 01 '18 at 23:12 by Doug MacArthur




2 Answers

For speed I would advise using HDF5 or LMDB:

Reasons to use LMDB:

LMDB uses memory-mapped files, giving much better I/O performance, and it works well with really large datasets. HDF5 files, by contrast, are always read entirely into memory, so you can't have any HDF5 file exceed your memory capacity. You can easily split your data into several HDF5 files though (just put several paths to h5 files in your text file). Then again, compared to LMDB's page caching, the I/O performance won't be nearly as good. [http://deepdish.io/2015/04/28/creating-lmdb-in-python/]

If you decide to use LMDB:

ml-pyxis is a tool for creating and reading deep learning datasets using LMDBs. (I am a co-author of this tool.)

It allows you to create binary blobs (LMDB), and they can be read quite fast. The link above comes with some simple examples of how to create and read the data, including Python generators/iterators.

This notebook has an example of how to create a dataset and read it in parallel while using PyTorch.
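
For illustration, here is a minimal sketch using the plain lmdb Python package directly (not ml-pyxis); the file name, key format, and map_size are just assumptions for the example:

import lmdb
import numpy as np

# write: one key/value pair per image
env = lmdb.open('train.lmdb', map_size=2 * 1024 ** 3)  # reserve ~2 GB of address space
images = np.random.rand(5000, 3, 64, 64).astype(np.float32)  # placeholder data
with env.begin(write=True) as txn:
    for i, img in enumerate(images):
        txn.put(f'{i:08d}'.encode(), img.tobytes())
env.close()

# read: only the requested record is pulled from the memory-mapped file
env = lmdb.open('train.lmdb', readonly=True, lock=False)
with env.begin() as txn:
    buf = txn.get(b'00000042')
    img = np.frombuffer(buf, dtype=np.float32).reshape(3, 64, 64)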

If you decide to use HDF5:

PyTables is a package for managing hierarchical datasets, designed to efficiently and easily cope with extremely large amounts of data.

https://www.pytables.org/
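
As a rough sketch (file and array names are only examples), appending images to an extendable on-disk array with PyTables could look like this:

import numpy as np
import tables

# create an extendable array on disk and append images in chunks
f = tables.open_file('train_images.h5', mode='w')
images = f.create_earray(f.root, 'images', tables.Float32Atom(), shape=(0, 3, 64, 64))
for _ in range(5):  # e.g. 5 chunks of 1000 images
    chunk = np.random.rand(1000, 3, 64, 64).astype(np.float32)  # placeholder data
    images.append(chunk)
f.close()

# read back a single image without loading the whole file
f = tables.open_file('train_images.h5', mode='r')
x = f.root.images[0]
f.close()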

answered Oct 26 '22 at 02:10 by OddNorg


Here is a concrete example to demonstrate what I meant. This assumes that you've already dumped the images into an hdf5 file (train_images.hdf5) using h5py.
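
If you haven't created that file yet, the dump step could look roughly like this (the dataset name 'images' and the use of one big in-memory array are just assumptions for the sketch):

import h5py
import numpy as np

images = np.random.rand(5000, 3, 64, 64).astype(np.float32)  # placeholder for the real image array

with h5py.File('train_images.hdf5', 'w') as hf:
    hf.create_dataset('images', data=images, compression='gzip')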

import h5py
hf = h5py.File('train_images.hdf5', 'r')

group_key = list(hf.keys())[0]
ds = hf[group_key]

# load only one example
x = ds[0]

# load a subset, slice (n examples) 
arr = ds[:n]

# should load the whole dataset into memory.
# this should be avoided
arr = ds[:]

In simple terms, ds can now be used as an iterator that yields images on the fly (i.e. it doesn't load the whole dataset into memory). This should make the whole run time blazing fast.

for idx, img in enumerate(ds):
   # do something with `img`
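
If you then want to feed this into PyTorch's DataLoader, one possible pattern (a sketch, not the only way; the file path and dataset key are assumptions) is to wrap the HDF5 file in a small Dataset that only touches the file in __getitem__:

import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class H5ImageDataset(Dataset):
    """Serves individual images from an HDF5 file on demand."""

    def __init__(self, path, key):
        self.path = path
        self.key = key
        self._file = None  # opened lazily, see __getitem__
        with h5py.File(path, 'r') as f:
            self._len = len(f[key])

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # open lazily so each DataLoader worker gets its own file handle
        if self._file is None:
            self._file = h5py.File(self.path, 'r')
        img = self._file[self.key][idx]
        return torch.from_numpy(img)

loader = DataLoader(H5ImageDataset('train_images.hdf5', 'images'),
                    batch_size=32, num_workers=2)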
answered Oct 26 '22 at 01:10 by kmario23