I'm having trouble writing a custom collate_fn function for the PyTorch DataLoader class. I need the custom function because my inputs have different dimensions.

I'm currently trying to write the baseline implementation of the Stanford MURA paper. The dataset consists of labeled studies, and a study may contain more than one image. I created a custom Dataset class that stacks a study's images using torch.stack.

The stacked tensor is then provided as input to the model, and the list of outputs is averaged to obtain a single output. This implementation works fine with DataLoader when batch_size=1. However, when I set batch_size to 8, as in the original paper, the DataLoader fails, since it uses torch.stack to collate the batch and the items in my batch have variable dimensions (each study can contain a different number of images).
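To make the failure concrete, here is a minimal reproduction of the shape mismatch (the image counts and sizes are made up for illustration):

```python
import torch

study_a = torch.randn(12, 3, 224, 224)  # a study with 12 images
study_d = torch.randn(3, 3, 224, 224)   # a study with a different image count

try:
    # this is effectively what the default collate does with batch_size > 1
    torch.stack([study_a, study_d])
except RuntimeError as err:
    print("default collate fails:", err)
```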
To fix this, I tried to implement a custom collate_fn:
def collate_fn(batch):
    # keep the variable-size per-study stacks in a plain list instead of stacking
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets
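For reference, this is how the collate_fn is wired into the loader; the toy dataset below and its per-study sizes are invented for illustration, not the real MURA data:

```python
import torch
from torch.utils.data import DataLoader, Dataset

def collate_fn(batch):
    imgs = [item['images'] for item in batch]           # list of per-study stacks
    targets = torch.LongTensor([item['label'] for item in batch])
    return imgs, targets

class ToyStudyDataset(Dataset):
    """Stand-in for the real study dataset; sizes and shapes are illustrative."""
    sizes = [12, 3, 7, 5]                               # images per study

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        return {'images': torch.randn(self.sizes[idx], 3, 224, 224),
                'label': idx % 2}

loader = DataLoader(ToyStudyDataset(), batch_size=2, collate_fn=collate_fn)
imgs, targets = next(iter(loader))
# imgs is a plain Python list of two tensors with different first dimensions;
# targets is a LongTensor of shape [2]
```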
Then in my training epoch loop, I loop through each batch like this:
for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)
However, this does not improve the epoch time at all: an epoch still takes as long as it did with a batch size of 1. I've also tried a batch size of 32, and that does not improve the timings either.
Am I doing something wrong? Is there a better approach to this?
Very interesting problem! If I understand you correctly (and also checking the abstract of the paper), you have 40,561 images from 14,863 studies, where each study is manually labeled by radiologists as either normal or abnormal.
I believe the issue you faced arises because each study's stack has a different first dimension: say, for example, study A stacks to a tensor of 12 images while study D stacks to a different number. When you then use a batch size of 8 during training, the default collate fails as soon as it tries to stack studies of mismatched sizes, e.g. when it gets to study D.
So: is there a reason why you need to average the list of outputs in a study into a single label? If not, I would simply collect all 40,561 images and assign each image the label of its study (such that the list of 12 outputs from study A is compared with a list of 12 copies of its label).
With a single dataloader you can then shuffle across studies (if desired) and use whatever batch size you want during training.
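A minimal sketch of this per-image approach (the class and variable names here are my own invention, not from the paper): every image simply inherits its study's label, so the default collate works and you get real batches.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FlatImageDataset(Dataset):
    """Flatten studies into individual images, repeating each study's label."""
    def __init__(self, studies):
        # studies: list of (image_stack, label) pairs, one entry per study
        self.samples = [(img, label)
                        for images, label in studies
                        for img in images]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, label = self.samples[idx]
        return img, torch.tensor(label, dtype=torch.float32)

# two toy studies: 12 images labeled normal (0), 3 images labeled abnormal (1)
studies = [(torch.randn(12, 3, 224, 224), 0),
           (torch.randn(3, 3, 224, 224), 1)]
loader = DataLoader(FlatImageDataset(studies), batch_size=8, shuffle=True)
images, labels = next(iter(loader))  # images has shape [8, 3, 224, 224]
```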
I see this question has been around for a while, I hope it helps someone in the future :)