How does the __getitem__'s idx work within PyTorch's DataLoader?

I'm currently trying to use PyTorch's DataLoader to process data to feed into my deep learning model, but am facing some difficulty.

The data that I need is of shape (minibatch_size=32, rows=100, columns=41). The __getitem__ code that I have within the custom Dataset class that I wrote looks something like this:

def __getitem__(self, idx):
    # Take a window of 100 consecutive rows (shape (100, 41)) starting at row idx
    x = np.array(self.train.iloc[idx:idx + 100, :])
    return x

I wrote it this way because I want the DataLoader to handle one input instance of shape (100, 41) at a time, with 32 of these instances forming one minibatch.

However, I noticed that, contrary to my initial belief, the idx argument that the DataLoader passes to the function is not sequential (this is crucial because my data is time series data). For example, printing the values gave me something like this:

idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067

Is this normal behavior? I'm posting this question because the data batches that are being returned are not what I initially wanted them to be.

The original logic that I used to preprocess the data before using the DataLoader was:

  1. Read data in from either txt or csv file.
  2. Calculate how many batches are in the data and slice the data accordingly. For example, since one input instance is of shape (100, 41) and 32 of these form one minibatch, we usually end up with around 100 batches, and reshape the data accordingly.
  3. One minibatch is then of shape (32, 100, 41).
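The preprocessing steps above amount to a trim-and-reshape. A minimal NumPy sketch, assuming the file has already been loaded into an (n_rows, 41) array; `to_minibatches` is a hypothetical helper name and the random data is only a stand-in:

```python
import numpy as np


def to_minibatches(data, batch_size=32, window=100):
    """Hypothetical helper: trim to whole minibatches, reshape to (n_batches, batch_size, window, n_cols)."""
    per_batch = batch_size * window          # rows consumed by one minibatch
    n_batches = len(data) // per_batch       # leftover rows at the end are dropped
    trimmed = data[: n_batches * per_batch]
    # C-order reshape keeps rows in time order: batch 0, window 0 is rows 0..99, etc.
    return trimmed.reshape(n_batches, batch_size, window, data.shape[1])


data = np.random.rand(330_000, 41)  # stand-in for the txt/csv contents
batches = to_minibatches(data)
print(batches.shape)  # (103, 32, 100, 41)
```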

I'm not sure how else I should be handling the DataLoader hook methods. Any tips or advice are greatly appreciated. Thanks in advance.

Sean asked Nov 13 '19 09:11



1 Answer

The idx is defined by the sampler or batch_sampler, as you can see in the DataLoader source code (open-source projects are your friend). That code, together with its comments and docstring, shows the difference between sampler and batch_sampler. The iterator implementation shows how the index is chosen:

def __next__(self):
    index = self._next_index()

# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
    return next(self._sampler_iter)

# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)

# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler

Note that this is the _SingleProcessDataLoaderIter implementation; there is also a _MultiProcessingDataLoaderIter, and which of the two is used depends on the num_workers value. Going back to the samplers: assuming your Dataset is not _DatasetKind.Iterable and that you are not providing a custom sampler, the DataLoader sets up its samplers like this (dataloader.py#L212-L215):

if shuffle:
    sampler = RandomSampler(dataset)
else:
    sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
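You can check which samplers a given DataLoader ended up with by inspecting its `sampler` and `batch_sampler` attributes. A minimal sketch, assuming a recent PyTorch version; `TinyDataset` is a hypothetical placeholder dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    """Hypothetical map-style dataset: item i is simply the tensor i."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor(idx)


ds = TinyDataset()

# shuffle=False (the default) -> SequentialSampler
seq_loader = DataLoader(ds, batch_size=4, shuffle=False)
print(type(seq_loader.sampler).__name__)        # SequentialSampler

# shuffle=True -> RandomSampler
rnd_loader = DataLoader(ds, batch_size=4, shuffle=True)
print(type(rnd_loader.sampler).__name__)        # RandomSampler

# batch_size is set and no batch_sampler was passed -> a BatchSampler wraps the sampler
print(type(seq_loader.batch_sampler).__name__)  # BatchSampler
```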

Let's take a look at how the default BatchSampler builds a batch:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

Very simple: it gets indices from the sampler until the desired batch_size is reached.
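To see this grouping in isolation, you can drive a BatchSampler directly with a SequentialSampler. A small sketch, using `range(10)` as a stand-in data source (SequentialSampler only needs its length):

```python
from torch.utils.data import BatchSampler, SequentialSampler

# SequentialSampler only calls len() on its data source, so range(10) is enough here.
sampler = SequentialSampler(range(10))

# Indices are grouped into lists of batch_size; the short last batch is kept.
batches = list(BatchSampler(sampler, batch_size=4, drop_last=False))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With drop_last=True (the DataLoader's drop_last is passed through) it is discarded.
dropped = list(BatchSampler(sampler, batch_size=4, drop_last=True))
print(dropped)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each of these lists is what ends up being fetched via `__getitem__`, one index at a time.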

Now the question "How does the __getitem__'s idx work within PyTorch's DataLoader?" can be answered by seeing how each default sampler works.

  • SequentialSampler (this is the full implementation; very simple, isn't it?):
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
  • RandomSampler (let's see only the __iter__ implementation):
def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())
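The effect on `__getitem__` is easy to observe with a dataset that records every idx it receives. A minimal sketch; `IndexRecorder` is a hypothetical name, and it relies on the default num_workers=0 so that `__getitem__` runs in the main process:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexRecorder(Dataset):
    """Hypothetical dataset that remembers every idx passed to __getitem__."""

    def __init__(self, n):
        self.n = n
        self.seen = []

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.seen.append(idx)
        return idx


# shuffle=False -> SequentialSampler: indices arrive as 0, 1, 2, ...
seq_ds = IndexRecorder(8)
for _ in DataLoader(seq_ds, batch_size=3, shuffle=False):
    pass
print(seq_ds.seen)  # [0, 1, 2, 3, 4, 5, 6, 7]

# shuffle=True -> RandomSampler: a random permutation of 0..7 (torch.randperm)
torch.manual_seed(0)
rnd_ds = IndexRecorder(8)
for _ in DataLoader(rnd_ds, batch_size=3, shuffle=True):
    pass
print(rnd_ds.seen)  # every index exactly once, in shuffled order
```

This is the same pattern as the idx values printed in the question: a permutation of all valid indices rather than a sequence.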

Therefore, since you did not show how you construct your DataLoader, we can only assume one of the following:

  1. You are using shuffle=True in your DataLoader or
  2. You are using a custom sampler or
  3. Your Dataset is _DatasetKind.Iterable
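If the goal is the one described in the question (time-ordered windows of shape (100, 41), 32 of them per minibatch), one option is to let each idx refer to a whole window instead of a single row and keep shuffle=False. A minimal sketch, assuming non-overlapping windows and a NumPy array in place of the question's `self.train` DataFrame; `WindowDataset` and `window` are hypothetical names:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class WindowDataset(Dataset):
    """Hypothetical dataset: item i is rows [i*window, (i+1)*window) of the data."""

    def __init__(self, data, window=100):
        self.data = np.asarray(data, dtype=np.float32)
        self.window = window

    def __len__(self):
        # Number of complete, non-overlapping windows; leftover rows are ignored.
        return len(self.data) // self.window

    def __getitem__(self, idx):
        start = idx * self.window
        return torch.from_numpy(self.data[start:start + self.window])


# Stand-in for the real time series: 6400 rows x 41 columns.
rows = np.arange(6400 * 41, dtype=np.float32).reshape(6400, 41)
ds = WindowDataset(rows, window=100)

# shuffle=False keeps windows in time order; default collation stacks 32 windows.
loader = DataLoader(ds, batch_size=32, shuffle=False)
batches = list(loader)
print(len(batches), batches[0].shape)  # 2 torch.Size([32, 100, 41])
```

With this layout, turning shuffle=True back on only changes the order in which whole windows are drawn; the rows inside each window stay in time order. For overlapping windows (stride 1), `__len__` would instead be `len(data) - window + 1` and `__getitem__` would slice `idx:idx + window`.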
Berriel answered Sep 27 '22 21:09
