
How is batch learning performed in PyTorch?

When you look at how a network architecture is built in PyTorch code, we need to extend torch.nn.Module; inside __init__ we define the modules of the network, and PyTorch will track the gradients of the parameters of these modules. Then, inside the forward function, we define how the forward pass should be performed for our network.
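For example, a typical definition looks something like this (the layer sizes here are made up, just to illustrate the pattern):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # modules whose parameters PyTorch will track
        self.fc1 = nn.Linear(10, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        # nothing here refers to the batch size
        x = F.relu(self.fc1(x))
        return self.fc2(x)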

The thing I do not understand here is how batch learning occurs. Nowhere in the definitions above, including the forward function, do we care about the batch dimension of the input to our network. The only thing we need to do to perform batch learning is to add an extra dimension to the input, corresponding to the batch size, but nothing inside the network definition changes when we work with batch learning. At least, this is what I have seen in the code here.

So, if everything I have explained so far is correct (I would really appreciate it if you let me know whether I have misunderstood something), how is batch learning performed when nothing regarding the batch size is declared inside the definition of our network class (the class that inherits from torch.nn.Module)? Specifically, I am interested in how the batch gradient descent algorithm is implemented in PyTorch when we just use nn.MSELoss with a batch dimension.

Asked Jun 19 '19 by Infintyyy

People also ask

How does batch size work in PyTorch?

In the PyTorch DataLoader, the batch size is defined as the number of samples processed before the model is updated. Only if the batch size equals the number of samples in the training data does each update use the entire training set (full-batch gradient descent); usually it is much smaller.

What does Batch mean in PyTorch?

The meaning of batch size is loading [batch size] training samples in one iteration. If your batch size is 100, then you get 100 samples in one iteration. The batch size does not equal the number of iterations unless by coincidence.
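As a minimal sketch (with made-up data and sizes), the batch size is set on the DataLoader and each iteration yields one batch:

import torch
from torch.utils.data import TensorDataset, DataLoader

# 1000 made-up samples with 10 features each
dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
loader = DataLoader(dataset, batch_size=100, shuffle=True)

for xb, yb in loader:
    print(xb.shape)   # torch.Size([100, 10]) -- one batch of 100 samples per iteration
    break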


1 Answer

Check this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()         

    def forward(self, x):
        print("Hi ma")        
        print(x)
        x = F.relu(x)
        return x

n = Net()
r = n(torch.tensor(-1))            # calling the module invokes forward via __call__
print(r)
r = n.forward(torch.tensor(1))     # works, but forward is not meant to be called directly
print(r)

out:

Hi ma
tensor(-1)
tensor(0)
Hi ma
tensor(1)
tensor(1)
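The same forward also accepts a batched input without any change to the class; for example, with a made-up batch of three values:

batch = torch.tensor([-1.0, 2.0, -3.0])   # extra (batch) dimension added to the input
print(n(batch))                           # tensor([0., 2., 0.]) -- relu applied over the whole batch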

The thing to remember is that forward should not be called directly. PyTorch makes the Module object n callable. The callable is implemented roughly like this:

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

So just calling n() will invoke forward automatically.

In general, __init__ defines the module structure and forward() defines operations on a single batch.

Those operations may be repeated for some structural elements if needed, or you may call functions on tensors directly, as we did with x = F.relu(x).

You got this right: everything in PyTorch is done in batches (mini-batches), since PyTorch is optimized to work this way.

This means that when you read images, you will not read a single one, but a batch of bs images at a time.
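To tie this back to the question about nn.MSELoss with a batch dimension, here is a minimal sketch (shapes and data are made up): the loss is averaged over the batch by default (reduction='mean'), so one backward()/step() uses the gradients of that batch-averaged loss.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                       # works for any batch size
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()                         # default reduction='mean' averages over the batch

x = torch.randn(100, 10)                       # one batch: 100 samples, 10 features each
y = torch.randn(100, 1)

pred = model(x)                                # shape (100, 1); the batch dimension just passes through
loss = loss_fn(pred, y)                        # a scalar: MSE averaged over the batch
opt.zero_grad()
loss.backward()                                # gradients of the batch-averaged loss
opt.step()                                     # one (mini-)batch gradient descent update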

Answered Oct 22 '22 by prosti