 

Efficient PyTorch DataLoader collate_fn function for inputs of various dimensions

I'm having trouble writing a custom collate_fn function for the PyTorch DataLoader class. I need the custom function because my inputs have different dimensions.

I'm currently trying to write the baseline implementation of the Stanford MURA paper. The dataset has a set of labeled studies. A study may contain more than one image. I created a custom Dataset class that stacks these multiple images using torch.stack.

The stacked tensor is then provided as input to the model, and the list of outputs is averaged to obtain a single output. This implementation works fine with DataLoader when batch_size=1. However, when I try to set the batch_size to 8, as in the original paper, the DataLoader fails since it uses torch.stack to stack the batch, and the inputs in my batch have variable dimensions (since each study can have a variable number of images).
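For reference, a Dataset along these lines can be sketched as follows. This is a minimal illustration, not the actual MURA implementation; the class name, the in-memory list of studies, and the toy shapes are all assumptions:

```python
import torch
from torch.utils.data import Dataset

class StudyDataset(Dataset):
    """Hypothetical sketch: one item per study, images stacked with torch.stack."""
    def __init__(self, studies, labels):
        # studies: list of lists of image tensors, each image of shape (C, H, W)
        # labels: one label per study
        self.studies = studies
        self.labels = labels

    def __len__(self):
        return len(self.studies)

    def __getitem__(self, idx):
        # stack the study's images into a single (num_images, C, H, W) tensor;
        # num_images varies from study to study
        images = torch.stack(self.studies[idx])
        return {'images': images, 'label': self.labels[idx]}
```

Because `num_images` differs between items, the default collate cannot stack two such items into one batch tensor.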

In order to fix this, I tried to implement my custom collate_fn function.

import torch

def collate_fn(batch):
    # keep the studies as a list of variable-size tensors instead of stacking them
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets

Then in my training epoch loop, I loop through each batch like this:

for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()  
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)

However, this does not improve the epoch time at all; an epoch still takes as long as it did with a batch size of 1. I've also tried setting the batch size to 32, and that does not improve the timings either.

Am I doing something wrong? Is there a better approach to this?

asked Jan 07 '19 by UserAnon


1 Answer

Very interesting problem! If I understand you correctly (and also checking the abstract of the paper), you have 40,561 images from 14,863 studies, where each study is manually labeled by radiologists as either normal or abnormal.

I believe the issue arises because of how the studies are stacked. Say, for example, you created a stack for:

  1. study A - 12 images
  2. study B - 13 images
  3. study C - 7 images
  4. study D - 1 image, etc.

And you try to use a batch size of 8 during training: the default collate fails because it cannot stack studies with different numbers of images into a single batch tensor.
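The failure mode above can be reproduced directly: torch.stack (which the default collate_fn relies on) requires all tensors to have identical shapes. A minimal sketch with made-up shapes:

```python
import torch

# two studies with different image counts, e.g. study A (12 images) and study D (1 image)
study_a = torch.randn(12, 3, 8, 8)
study_d = torch.randn(1, 3, 8, 8)

try:
    # this is what the default collate_fn attempts when batching
    torch.stack([study_a, study_d])
    stacked = True
except RuntimeError:
    # shapes differ along dim 0, so stacking raises a RuntimeError
    stacked = False
```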

So, is there a reason why you want to average the list of outputs in a study down to a single label? If not, I would simply collect all 40,561 images and assign each image the label of its study (so that the list of 12 outputs for study A is compared against a list of 12 identical labels).

That way, with a single dataloader, you can shuffle across studies (if desired) and use the desired batch size during training, since every item now has the same shape.
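The per-image approach above can be sketched like this. It is a minimal illustration under the stated assumption (one label repeated per image); the class name and shapes are hypothetical:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FlatImageDataset(Dataset):
    """Hypothetical sketch: flatten studies into individual images,
    repeating each study's label for every image it contains."""
    def __init__(self, studies, study_labels):
        # studies: list of lists of image tensors, each image of shape (C, H, W)
        self.samples = [(img, lbl)
                        for imgs, lbl in zip(studies, study_labels)
                        for img in imgs]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, lbl = self.samples[idx]
        return img, torch.tensor(lbl, dtype=torch.long)

# every item now has a fixed shape, so the default collate_fn works
# with any batch size and no custom collate_fn is needed
```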

I see this question has been around for a while, I hope it helps someone in the future :)

answered Oct 19 '22 by Daniel Ajisafe