I am currently training a 3D CNN for binary classification with relatively sparse labels (~1% of voxels in the label data belong to the target class). In order to perform basic sanity checks during training (e.g., does the network learn at all?), it would be handy to present the network with a small, handpicked subset of training examples having an above-average fraction of target-class labels.
As suggested by the PyTorch documentation, I implemented my own dataset class (inheriting from torch.utils.data.Dataset) which provides training examples via its __getitem__ method to the torch.utils.data.DataLoader.
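For reference, a minimal sketch of such a dataset class (the class name, attribute names, and tensor shapes here are assumptions, not my actual code):

from torch.utils.data import Dataset

class VoxelDataset(Dataset):
    # Hypothetical 3D volume dataset
    def __init__(self, volumes, labels):
        # volumes: (N, C, D, H, W) float tensor; labels: (N, D, H, W) binary tensor
        self.volumes = volumes
        self.labels = labels

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        # Return one (input, label) pair for the linear index idx
        return self.volumes[idx], self.labels[idx]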
In the PyTorch tutorials I found, the DataLoader is used as an iterator to drive the training loop, like so:
for i, data in enumerate(self.dataloader):
    # Get training data
    inputs, labels = data
    # Train the network
    # [...]
What I am wondering now is whether there exists a simple way to load a single specific training example (or a few of them), using the linear index understood by the Dataset's __getitem__ method. However, DataLoader does not have a __getitem__ method, and repeatedly calling __next__ until I reach the desired index does not seem elegant.
Apparently one possible way to solve this would be to define a custom sampler or batch_sampler inheriting from the abstract torch.utils.data.Sampler. But this seems over the top just to retrieve a few specific samples.
I suppose I am overlooking something very simple and obvious here. Any advice appreciated!
Just in case anyone with a similar question comes across this at some point:
The quick-and-dirty workaround I ended up using was to bypass the dataloader in the training loop by directly accessing its associated dataset attribute. Suppose we want to quickly check whether our network learns at all by repeatedly presenting it a single, handpicked training example with linear index sample_idx (as defined by the dataset class).
Then one can do something like this:
for i, _ in enumerate(self.dataloader):
    # Get training data directly from the dataset, ignoring what the loader yields
    # inputs, labels = data
    inputs, labels = self.dataloader.dataset[sample_idx]
    # Add a batch dimension of size 1, since the loader's collation is bypassed
    inputs = inputs.unsqueeze(0)
    labels = labels.unsqueeze(0)
    # Train the network
    # [...]
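A slightly cleaner variant of the same idea, for anyone on a PyTorch version that ships torch.utils.data.Subset: restrict the dataset to the handpicked index up front and build a throwaway loader from that, so the batch dimension is handled for you. This is a sketch under those assumptions, not something I have tested against the setup above:

from torch.utils.data import DataLoader, Subset

# Restrict the dataset to the single handpicked example
overfit_set = Subset(self.dataloader.dataset, [sample_idx])
overfit_loader = DataLoader(overfit_set, batch_size=1, num_workers=0)

for inputs, labels in overfit_loader:
    # inputs/labels already carry a batch dimension here
    # Train the network
    # [...]
    pass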
EDIT:
One brief remark, since some people seem to find this workaround helpful: when using this hack, I found it crucial to instantiate the DataLoader with num_workers=0. Otherwise, memory segmentation errors might occur, in which case you could end up with very weird-looking training data.
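For concreteness, a sketch of how that instantiation might look (the dataset name and batch size are placeholders):

from torch.utils.data import DataLoader

# num_workers=0 keeps data loading in the main process; this avoids the
# multiprocessing issues described above when the dataset is also
# accessed directly outside the loader
dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True, num_workers=0)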