
How to adapt the GPU batch size during training?

I found it surprising that I could not find any resources online on how to dynamically adapt the GPU batch size without halting training.

The idea is the following:

1) Have a training script that is (almost) agnostic to the GPU in use. The batch size will adjust dynamically without user intervention or the need for tuning.

2) Still be able to specify the desired training batch size, even if it is too big to fit in the biggest known GPU.

For instance, let's say I want to train a model using a batch size of 4096 images, each image 1024x1024. Let's also say that I have access to a server with different NVIDIA GPUs, but I don't know which one will be assigned to me in advance. (Or that everybody wants to use the biggest GPU and I am left waiting a long time before it is my turn.)

I want my training script to find the max batch size (let's say it is 32 images per GPU batch), and only update the optimizer when all 4096 images have been processed (one training batch = 128 GPU batches), as in the sketch below.
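In other words, what I am after is plain gradient accumulation. A minimal, self-contained sketch of the arithmetic (the tiny linear model and random data below are stand-ins, not part of my actual setup):

import torch
import torch.nn as nn
import torch.optim as optim

# The numbers from the example above: an effective batch of 4096 images
# processed as 128 GPU batches of 32.
train_batch_size = 4096
gpu_batch_size = 32
accumulation_steps = train_batch_size // gpu_batch_size  # = 128

model = nn.Linear(8, 2)  # stand-in for the real network
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
for i in range(accumulation_steps):
    x = torch.randn(gpu_batch_size, 8)           # stand-in GPU batch
    y = torch.randint(0, 2, (gpu_batch_size,))
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate in the parameters' .grad

# a single optimizer update once all 4096 images have been processed
optimizer.step()
optimizer.zero_grad()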

asked Oct 16 '22 by Victor Zuanazzi
2 Answers

There are different ways of solving this problem. But if requesting a GPU that can do the job, or using multiple GPUs, is not an option, then it is handy to dynamically adapt the GPU batch size.

I prepared this repo with an illustrative training example in PyTorch (it should work similarly in TensorFlow).

In the code below, the try/except is used to try different GPU batch sizes without halting training. When the batch size becomes too large, it is halved and the adaptation is turned off. Please check the repo for the implementation details and possible bug fixes.

The code also implements a technique called batch spoofing, which performs a number of forward and backward passes before the backpropagated gradients are applied. In PyTorch it only requires delaying optimizer.step() and optimizer.zero_grad() until enough GPU batches have been accumulated.

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn

# Example of how to use it with PyTorch
if __name__ == "__main__":

    # #############################################################
    # 1) Initialize the dataset, model, optimizer and loss as usual.
    # Initialize a fake dataset

    # the ToTensor transform is needed so that the DataLoader yields
    # tensors rather than PIL images
    trainset = torchvision.datasets.FakeData(size=1_000_000,
                                             image_size=(3, 224, 224),
                                             num_classes=1000,
                                             transform=torchvision.transforms.ToTensor())

    # initialize the model, loss and SGD-based optimizer.
    # The model must live on the GPU, otherwise no CUDA out-of-memory
    # error can ever be raised.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet = torchvision.models.resnet152(pretrained=True,
                                          progress=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(resnet.parameters(), lr=0.01)

    continue_training = True  # flag: keep training while True

    # #############################################################
    # 2) Set parameters for the adaptive batch size
    adapt = True  # while this is true, the algorithm will perform batch adaptation
    gpu_batch_size = 2  # initial gpu batch_size, it can be super small
    train_batch_size = 2048  # the desired effective training batch size

    # Modified training loop to allow for adaptive batch size
    while continue_training:

        # #############################################################
        # 3) Initialize dataloader and batch spoofing parameter
        # The dataloader has to be reinitialized for each new batch size.
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=int(gpu_batch_size),
                                                  shuffle=True)

        # Number of repetitions for batch spoofing
        repeat = max(1, int(train_batch_size / gpu_batch_size))

        try:  # This will make sure that training is not halted when the batch size is too large

            # #############################################################
            # 4) Epoch loop with batch spoofing
            optimizer.zero_grad()  # done before training because of batch spoofing.

            for i, (x, y) in enumerate(trainloader):

                # move the data to the same device as the model
                x, y = x.to(device), y.to(device)

                y_pred = resnet(x)
                # scale the loss so the accumulated gradient matches a
                # single update over the full training batch
                loss = criterion(y_pred, y) / repeat
                loss.backward()

                # batch spoofing: update only every `repeat` GPU batches
                if (i + 1) % repeat == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                # #############################################################
                # 5) Adapt the batch size while no RuntimeError is raised.
                # Double the batch size and get out of the loop
                if adapt:
                    gpu_batch_size *= 2
                    break

                # Stopping criterion for this illustrative example
                if i > 100:
                    continue_training = False
                    break

        # #############################################################
        # 6) After the largest batch size is found, training progresses
        # with the fixed batch size.
        # CUDA out of memory raises a RuntimeError, which we hit as soon
        # as the batch size becomes too large.
        except RuntimeError as run_error:
            gpu_batch_size /= 2  # halve back to the largest batch size that fit
            adapt = False  # turn off the batch adaptation

            # discard the partial gradients of the failed batch and
            # release the cached memory it left behind
            optimizer.zero_grad()
            torch.cuda.empty_cache()

            # Number of repetitions for batch spoofing
            repeat = max(1, int(train_batch_size / gpu_batch_size))

            # Manually check whether the RuntimeError was caused by CUDA
            # running out of memory or by something else.
            print(f"---\nRuntimeError: \n{run_error}\n---\n Is it a cuda error?")
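One caveat worth noting: batch spoofing (i.e. gradient accumulation) is not numerically identical to a true large batch when the model contains batch normalization layers, as resnet152 does. The BatchNorm statistics are computed over each GPU batch rather than over the accumulated training batch, so very small GPU batches can hurt results even though the gradient arithmetic works out.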

If you have code that does something similar in TensorFlow, Caffe or others, please share!

answered Oct 20 '22 by Victor Zuanazzi


"how to dynamically adapt the GPU batch size without halting training"

There is a very similar question that uses a random sampler for the job; a sketch of that idea follows.
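A minimal, self-contained sketch (FakeData is just a stand-in dataset; the trick relies on BatchSampler reading its batch_size attribute lazily at iteration time, which holds in current PyTorch but is an implementation detail rather than a documented guarantee):

import torchvision
from torch.utils.data import BatchSampler, RandomSampler, DataLoader

# any map-style dataset works here
trainset = torchvision.datasets.FakeData(size=256,
                                         image_size=(3, 32, 32),
                                         num_classes=10,
                                         transform=torchvision.transforms.ToTensor())

# the batch size lives on the sampler, not on the DataLoader
sampler = BatchSampler(RandomSampler(trainset), batch_size=8, drop_last=False)
loader = DataLoader(trainset, batch_sampler=sampler)

for x, y in loader:       # yields batches of 8
    break

sampler.batch_size = 16   # change it between epochs, no rebuild needed
for x, y in loader:       # now yields batches of 16
    break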

I will just add another option: DataLoader has a collate_fn argument you could use to alter the batch size, as sketched after the docs quote below.

collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
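A minimal sketch of how that hook could shrink batches on the fly; make_collate and max_items are hypothetical names, and default_collate is exported from torch.utils.data in recent PyTorch (older versions expose it as torch.utils.data.dataloader.default_collate):

import torchvision
from torch.utils.data import DataLoader, default_collate

def make_collate(max_items):
    # hypothetical helper: keep at most max_items samples per fetched batch
    def collate(samples):
        return default_collate(samples[:max_items])
    return collate

trainset = torchvision.datasets.FakeData(size=256,
                                         image_size=(3, 32, 32),
                                         num_classes=10,
                                         transform=torchvision.transforms.ToTensor())

loader = DataLoader(trainset, batch_size=64, shuffle=True,
                    collate_fn=make_collate(max_items=32))

for x, y in loader:
    print(x.shape)  # torch.Size([32, 3, 32, 32])
    break

Note that the trimmed samples are simply discarded each batch, so this is more an illustration of the hook than an efficient strategy.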

answered Oct 20 '22 by prosti