PyTorch: Speed up data loading

I am using densenet121 to do cat/dog detection on the Kaggle dataset. I enabled CUDA, and training appears to be very fast. However, the data loading (or perhaps the processing) appears to be very slow. Are there ways to speed it up? I tried playing with the batch size, but that didn't help much. I also changed num_workers from 0 to some positive numbers: going from 0 to 2 reduces loading time by perhaps 1/3, and increasing it further has no additional effect. Are there other ways I can speed up loading?

This is my rough code (I am focused on learning, so it's not very organized):

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

data_dir = 'Cat_Dog_data'

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5],
                                                            [0.5, 0.5, 0.5])])
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      # Same normalization as the training data
                                      transforms.Normalize([0.5, 0.5, 0.5],
                                                           [0.5, 0.5, 0.5])])

# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train',
                                  transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=64,
                                          num_workers=16, shuffle=True,
                                          pin_memory=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64,
                                         num_workers=16)

model = models.densenet121(pretrained=True)

# Freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False

from collections import OrderedDict

classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(1024, 500)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(500, 2)),
    ('output', nn.LogSoftmax(dim=1))
]))

model.classifier = classifier
model.cuda()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

epochs = 30
steps = 0

import time

device = torch.device('cuda:0')

train_losses, test_losses = [], []
for e in range(epochs):
    running_loss = 0
    count = 0
    total_start = time.time()
    for images, labels in trainloader:
        start = time.time()
        images = images.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()

        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()
        # CUDA calls are asynchronous; wait for the GPU before reading the clock
        torch.cuda.synchronize()
        elapsed = time.time() - start

        if count % 20 == 0:
            print("Optimized elapsed: ", elapsed, "count:", count)
            print("Total elapsed ", time.time() - total_start)
            total_start = time.time()
        count += 1

        running_loss += loss.item()
    else:
        test_loss = 0
        accuracy = 0
        model.eval()  # switch to eval mode once, not on every batch
        with torch.no_grad():
            for images, labels in testloader:
                images = images.cuda()
                labels = labels.cuda()
                log_ps = model(images)
                # .item() accumulates a Python float instead of a GPU tensor
                test_loss += criterion(log_ps, labels).item()
                ps = torch.exp(log_ps)
                top_p, top_class = ps.topk(1, dim=1)
                compare = top_class == labels.view(*top_class.shape)
                accuracy += compare.float().mean().item()
        model.train()
        train_losses.append(running_loss / len(trainloader))
        test_losses.append(test_loss / len(testloader))

        print("Epoch: {}/{}.. ".format(e + 1, epochs),
              "Training Loss: {:.3f}.. ".format(
                  running_loss / len(trainloader)),
              "Test Loss: {:.3f}.. ".format(test_loss / len(testloader)),
              "Test Accuracy: {:.3f}".format(accuracy / len(testloader)))
asked Apr 23 '20 17:04 by gruszczy




1 Answer

torchvision 0.8.0 or greater

torchvision now supports batched transformations that can run on the GPU (they operate on torch.Tensors instead of PIL images), so this should be the first improvement to try.

See the torchvision 0.8.0 release notes for more information. These transforms are also torch.nn.Module subclasses, so they can be composed with torch.nn.Sequential and even used inside a model, for example:

import torch
import torchvision.transforms as T

transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

Furthermore, these modules can be compiled with TorchScript (torch.jit.script), possibly improving performance even further.

torchvision < 0.8.0 (original answer)

Increasing batch_size won't help, because torchvision applies the transforms to one image at a time as each image is loaded from disk.

There are a couple of ways to speed up data loading, in increasing order of difficulty:

  • Improve image loading times
  • Load & normalize images and cache in RAM (or on disk)
  • Produce transformations and save them to disk
  • Apply non-cacheable transforms (rotations, flips, crops) in a batched manner
  • Prefetching

1. Improve image loading

An easy improvement is to install Pillow-SIMD instead of the original Pillow. It is a drop-in replacement and could be faster (or so it is claimed, at least for Resize, which you are using).

Alternatively, you could write your own data loading and processing with OpenCV, which some say is faster, or check out albumentations. I can't tell you whether either will improve performance, though, and it might be a lot of time spent for no gain beyond the learning experience.

2. Load & normalize images & cache

You can use Python's LRU cache (functools.lru_cache) to cache some outputs.

You can also use torchdata, which acts almost exactly like PyTorch's torch.utils.data.Dataset but allows caching to disk or in RAM (or mixed modes) with a simple cache() call on torchdata.Dataset (see the GitHub repository; disclaimer: I'm the author).

Remember: load and normalize the images, cache them, and only after that apply RandomRotation, RandomResizedCrop and RandomHorizontalFlip, because those give a different result every time they run.

3. Produce transformations and save them to disk

You would have to run many random transformations over the images, save the results to disk, and train on this augmented dataset afterwards. Once again, that could be done with torchdata, but it is really wasteful in terms of I/O and disk space, and a very inelegant solution. Furthermore, it is "static": the data would only last you for X epochs rather than being an "infinite" generator of augmentations.

4. Batched transformations

torchvision does not support this, so you would have to write those functions yourself. See this issue for the justification. AFAIK no other third-party library provides it either. For large batches it should speed things up, but the implementation is an open question, I think (correct me if I'm wrong).
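As an illustration of what such a hand-written batched transform could look like, here is a sketch of a random horizontal flip that decides independently for each image in a `[B, C, H, W]` batch (`batched_random_hflip` is a hypothetical name). It flips the whole batch once and then picks, per image, the flipped or original version with `torch.where`, so the work stays a few vectorized tensor ops on either the CPU or the GPU.

```python
import torch


def batched_random_hflip(batch: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: flip each image of a [B, C, H, W] batch
    horizontally, independently, with probability p."""
    # One random decision per image, created on the batch's device.
    flip = torch.rand(batch.shape[0], device=batch.device) < p
    flipped = batch.flip(-1)  # mirror along the width dimension
    # Broadcast the [B] mask to [B, 1, 1, 1] and choose per image.
    return torch.where(flip[:, None, None, None], flipped, batch)
```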

5. Prefetch

IMO this would be the hardest to implement (though, come to think of it, a really good idea for the project). Basically, you load the data for the next iteration while your model trains. torch.utils.data.DataLoader does provide this, though there are some concerns (such as workers pausing once their data has been loaded). You can read the PyTorch thread about it (I'm not sure about it, as I didn't verify it myself). A lot of valuable insight is also provided by this comment and this blog post (though I'm not sure how up to date those are).
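Note that `DataLoader` with `num_workers > 0` already prefetches; in PyTorch 1.7+ the `prefetch_factor` argument controls how many batches each worker loads ahead. To overlap work that still happens in the main process, one generic sketch is a background thread feeding a bounded queue. `BackgroundPrefetcher` is a hypothetical helper, and because it uses a thread it only helps with work that releases the GIL (file I/O, many image decoders, tensor ops).

```python
import queue
import threading


class BackgroundPrefetcher:
    """Hypothetical helper: pulls items from `iterable` in a background thread
    and keeps up to `depth` of them ready for the training loop."""

    def __init__(self, iterable, depth=2):
        self.iterable = iterable
        self.depth = depth

    def _produce(self, q):
        try:
            for item in self.iterable:
                q.put((False, item))  # blocks while the queue is full
            q.put((True, None))  # normal end of iteration
        except Exception as exc:
            q.put((True, exc))  # hand the error to the consuming thread

    def __iter__(self):
        q = queue.Queue(maxsize=self.depth)
        threading.Thread(target=self._produce, args=(q,), daemon=True).start()
        while True:
            done, payload = q.get()
            if done:
                if payload is not None:
                    raise payload
                return
            yield payload
```

Usage would be along the lines of `for images, labels in BackgroundPrefetcher(trainloader, depth=2): ...`. If the loop exits early, the producer thread stays blocked on the full queue; it is a daemon thread, so it does not keep the process alive.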

All in all, to substantially improve data loading you would need to get your hands quite dirty (or maybe there are libraries that do some of this for PyTorch; if so, I would love to know about them).

Also remember to profile your changes; see torch.utils.bottleneck.
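Besides `torch.utils.bottleneck`, a quick manual check is to split wall time into "waiting for the next batch" and "compute". Below is a sketch with a hypothetical `time_loader` helper, where `model_step` stands for whatever your loop does with one batch. Because CUDA kernels run asynchronously, it calls `torch.cuda.synchronize()` before reading the clock; otherwise the compute time would only measure kernel launches.

```python
import time

import torch


def time_loader(loader, model_step, max_batches=50):
    """Hypothetical helper: return (seconds waiting for data, seconds computing)."""
    data_time = compute_time = 0.0
    it = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent here is pure data loading
        except StopIteration:
            break
        t1 = time.perf_counter()
        model_step(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued GPU work to finish
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time
```

If data time dominates, the loader is the bottleneck and the techniques above are worth trying; if compute dominates, faster loading will not help much.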

EDIT: NVIDIA's DALI project might be worth checking out, though AFAIK it has some problems with RAM usage growing linearly with the number of epochs.

answered Nov 15 '22 11:11 by Szymon Maszke