
Simple way to load specific sample using Pytorch dataloader

I am currently training a 3D CNN for binary classification with relatively sparse labels (~1% of voxels in the label data correspond to the target class).

In order to perform basic sanity checks during the training (e.g. does the network learn at all?) it would be handy to present the network with a small, handpicked subset of training examples having an above-average fraction of target class labels.

As suggested by the PyTorch documentation, I implemented my own dataset class (inheriting from torch.utils.data.Dataset) which provides training examples via its __getitem__ method to the torch.utils.data.DataLoader.

In the PyTorch tutorials I found, the DataLoader is used as an iterator driving the training loop, like so:

for i, data in enumerate(self.dataloader):

    # Get training data
    inputs, labels = data

    # Train the network
    # [...]

What I am wondering now is whether there exists a simple way to load a single or a couple of specific training examples (using the linear index understood by the Dataset's __getitem__ method). However, DataLoader does not have a __getitem__ method, and repeatedly calling __next__ until I reach the desired index does not seem elegant.

Apparently one possible way to solve this would be to define a custom sampler or batch_sampler inheriting from the abstract torch.utils.data.Sampler. But this seems over the top to retrieve a few specific samples.
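For reference, such a sampler is actually only a few lines. Here is a minimal sketch (the class name and index list are made up for illustration) of a Sampler that yields a fixed, handpicked set of indices:

```python
from torch.utils.data import Sampler

class FixedIndicesSampler(Sampler):
    """Yields a fixed, handpicked list of dataset indices."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # DataLoader will request exactly these indices, in this order
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

# Usage (dataset and indices are illustrative):
# loader = DataLoader(dataset, batch_size=1,
#                     sampler=FixedIndicesSampler([3, 17, 42]))
```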

I suppose I am overlooking something very simple and obvious here. Any advice appreciated!

Florian Drawitsch asked Feb 19 '19
1 Answer

Just in case anyone with a similar question comes across this at some point:

The quick-and-dirty workaround I ended up using was to bypass the dataloader in the training loop by directly accessing its associated dataset attribute. Suppose we want to quickly check whether our network learns at all by repeatedly presenting it with a single, handpicked training example with linear index sample_idx (as defined by the dataset class).

Then one can do something like this:

for i, _ in enumerate(self.dataloader):

    # Fetch the handpicked sample directly from the underlying dataset,
    # bypassing the dataloader's own batching
    inputs, labels = self.dataloader.dataset[sample_idx]

    # Add a batch dimension of size 1, as the dataloader would
    inputs = inputs.unsqueeze(0)
    labels = labels.unsqueeze(0)

    # Train the network
    # [...]

EDIT:

One brief remark, since some people seem to be finding this workaround helpful: when using this hack I found it crucial to instantiate the DataLoader with num_workers = 0. Otherwise, memory segmentation errors might occur, in which case you could end up with very weird-looking training data.
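Concretely, that instantiation looks like the sketch below (the toy dataset and batch size are illustrative; only the num_workers=0 part matters here):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the real dataset (shapes are illustrative)
dataset = TensorDataset(torch.zeros(8, 1, 4, 4, 4),
                        torch.zeros(8, 1, 4, 4, 4))

# num_workers=0 keeps all loading in the main process, avoiding the
# cross-process memory issues described above
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
```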

Florian Drawitsch answered Sep 25 '22