How does a PyTorch Module do backprop?

While following the instructions on extending PyTorch (adding a Module), I noticed that when extending Module we don't actually have to implement the backward function. All we need to do is apply the Function instance in the forward function, and PyTorch automatically calls the backward method of that Function instance during backprop. This seems like magic to me, since we never even registered the Function instance we used. I looked into the source code but didn't find anything related. Could anyone point me to the place where all this actually happens?
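The pattern described above can be sketched without PyTorch at all. In this toy sketch, Tensor, ScaleFunction and ScaleModule are hypothetical names, not PyTorch classes. The one assumption, loosely modeled on PyTorch's grad_fn attribute, is that a Function's apply() stores itself on its output, and that back-pointer is how backward() finds it later without any explicit registration:

```python
# Toy sketch only: these classes are NOT PyTorch's; they mimic a Module whose
# forward() applies a Function that knows its own backward().

class Tensor:
    def __init__(self, value, requires_grad=False, grad_fn=None):
        self.value = value
        self.requires_grad = requires_grad
        self.grad_fn = grad_fn   # node that produced this tensor (None for leaves)
        self.grad = None

    def backward(self, grad=1.0):
        if self.grad_fn is None:
            # Leaf: accumulate the incoming gradient if it is tracked.
            if self.requires_grad:
                self.grad = grad if self.grad is None else self.grad + grad
            return
        # Non-leaf: ask the producing node for input gradients, then recurse.
        for inp, g in zip(self.grad_fn.inputs, self.grad_fn.backward(grad)):
            inp.backward(g)

class ScaleFunction:
    """Hypothetical Function: y = w * x, with its own backward()."""
    @classmethod
    def apply(cls, x, w):
        node = cls()
        node.inputs = (x, w)                 # remember inputs for the backward pass
        return Tensor(x.value * w.value,
                      requires_grad=x.requires_grad or w.requires_grad,
                      grad_fn=node)          # output points back at its creator

    def backward(self, grad_out):
        x, w = self.inputs
        return (grad_out * w.value, grad_out * x.value)   # dL/dx, dL/dw

class ScaleModule:
    """Hypothetical Module: defines forward() only, no backward()."""
    def __init__(self, w):
        self.weight = Tensor(w, requires_grad=True)

    def forward(self, x):
        return ScaleFunction.apply(x, self.weight)

m = ScaleModule(3.0)
y = m.forward(Tensor(2.0))
y.backward()           # reaches ScaleFunction.backward through y.grad_fn
print(m.weight.grad)   # 2.0
```

The Module never needs a backward() because the Function object travels with the data it produced.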

asked Apr 01 '18 at 04:04 by NoSegfault



People also ask

How does PyTorch's backward() work?

The backward() method in PyTorch computes gradients during the backward pass of a neural network. If backward() is never called, no gradients are computed for the tensors. Gradients are accumulated into the .grad attribute of leaf tensors that have requires_grad set to True.
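A minimal toy sketch of that behavior, assuming a hypothetical Scalar class (not part of PyTorch): nothing is filled in until backward() runs, and only the leaf that requires a gradient receives one.

```python
class Scalar:
    """Toy stand-in for a tensor: a value plus an optional gradient slot."""
    def __init__(self, value, requires_grad=False, parents=()):
        self.value = value
        self.requires_grad = requires_grad
        self.parents = parents   # (parent, local_derivative) pairs
        self.grad = None

    def __mul__(self, other):
        return Scalar(self.value * other.value,
                      requires_grad=self.requires_grad or other.requires_grad,
                      parents=((self, other.value), (other, self.value)))

    def __add__(self, other):
        return Scalar(self.value + other.value,
                      requires_grad=self.requires_grad or other.requires_grad,
                      parents=((self, 1.0), (other, 1.0)))

    def backward(self, grad=1.0):
        if not self.requires_grad:
            return   # nothing upstream needs a gradient
        if not self.parents:
            # Leaf with requires_grad=True: store (accumulate) the gradient.
            self.grad = grad if self.grad is None else self.grad + grad
            return
        for parent, local in self.parents:
            parent.backward(grad * local)   # chain rule

x = Scalar(2.0, requires_grad=True)
c = Scalar(5.0)          # requires_grad=False: treated as a constant
y = x * c + x            # y = 5x + x, so dy/dx = 6
print(x.grad, c.grad)    # None None  (backward() not called yet)
y.backward()
print(x.grad, c.grad)    # 6.0 None
```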

How does Autograd in PyTorch work?

Autograd is a reverse-mode automatic differentiation system. Conceptually, as you execute operations, autograd records a graph of all the operations that created the data, giving you a directed acyclic graph whose leaves are the input tensors and whose roots are the output tensors.
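To make the leaf/root terminology concrete, here is a toy sketch with a hypothetical Node record (not autograd's real node type): each executed operation creates a node with edges to its inputs, and walking from the root output recovers the leaf inputs.

```python
from dataclasses import dataclass, field

@dataclass(eq=False)
class Node:
    """Toy graph node: an op name plus the nodes it was computed from."""
    op: str
    inputs: list = field(default_factory=list)

def record(op, *inputs):
    # Every executed operation creates one node whose edges point at its inputs.
    return Node(op, list(inputs))

a = Node("leaf:a")
b = Node("leaf:b")
h = record("mul", a, b)
out = record("add", h, a)   # 'a' is reused, so this is a DAG, not a tree

def leaves(root):
    """Walk from the root (output) toward the leaves (inputs)."""
    seen, stack, found = set(), [root], []
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue
        seen.add(id(n))
        if not n.inputs:
            found.append(n.op)
        stack.extend(n.inputs)
    return sorted(found)

print(leaves(out))  # ['leaf:a', 'leaf:b']
```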

What does requires_grad=True do?

When tensors have requires_grad=True, PyTorch starts building a backward graph that tracks every operation applied to them, so that gradients can be computed later. Because this graph is built on the fly as the code runs, it is called a dynamic computation graph (DCG).
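"Dynamic" means the recorded graph follows whatever Python code actually ran. A toy sketch, assuming a hypothetical Var class that keeps a linear history of ops instead of a real graph: different inputs take different branches and so record different graphs, and an untracked value records nothing.

```python
class Var:
    """Toy value; when requires_grad is True it carries the ops that built it."""
    def __init__(self, value, requires_grad=False, history=()):
        self.value = value
        self.requires_grad = requires_grad
        self.history = history   # recorded operations (empty if not tracked)

    def _record(self, op, value):
        hist = self.history + (op,) if self.requires_grad else ()
        return Var(value, self.requires_grad, hist)

    def __mul__(self, k):
        return self._record(f"mul({k})", self.value * k)

    def __add__(self, k):
        return self._record(f"add({k})", self.value + k)

def f(x):
    # Plain Python control flow decides which ops run, so the recorded
    # history is rebuilt, possibly with a different shape, on every call.
    if x.value > 0:
        return x * 2
    return (x + 1) * 3

print(f(Var(1.0, requires_grad=True)).history)   # ('mul(2)',)
print(f(Var(-1.0, requires_grad=True)).history)  # ('add(1)', 'mul(3)')
print(f(Var(1.0)).history)                       # ()
```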

How does automatic differentiation work in PyTorch?

PyTorch computes the gradient of a function with respect to its inputs using automatic differentiation. Automatic differentiation is a technique that, given a computational graph, computes the gradients with respect to the inputs. It can be performed in two ways: forward mode and reverse mode.
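The difference between the two modes can be sketched on f(x, y) = x*y + x, whose partial derivatives are y + 1 and x. This is a toy illustration, not PyTorch code: forward mode uses a hypothetical Dual number class and needs one pass per input, while the hand-unrolled reverse pass gets both partials from a single backward sweep.

```python
class Dual:
    """Forward mode: carry (value, derivative w.r.t. one chosen input)."""
    def __init__(self, val, dot):
        self.val, self.dot = val, dot

    def __add__(self, o):
        return Dual(self.val + o.val, self.dot + o.dot)

    def __mul__(self, o):
        # Product rule applied alongside the value.
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

def f(x, y):
    return x * y + x

# Forward mode: one pass per input, seeding that input's derivative with 1.
dfdx_fwd = f(Dual(3.0, 1.0), Dual(4.0, 0.0)).dot
dfdy_fwd = f(Dual(3.0, 0.0), Dual(4.0, 1.0)).dot

# Reverse mode (hand-unrolled): one forward pass, then one backward pass
# that yields the gradient for every input at once.
def f_reverse(x, y):
    p = x * y                 # forward pass
    out = p + x
    g_out = 1.0               # backward pass, from the output toward the inputs
    g_p = g_out               # d(out)/dp = 1
    g_x = g_out + g_p * y     # x reaches out directly and through p
    g_y = g_p * x
    return out, (g_x, g_y)

print(dfdx_fwd, dfdy_fwd)       # 5.0 3.0
print(f_reverse(3.0, 4.0)[1])   # (5.0, 3.0)
```

With many inputs and one scalar loss, the reverse sweep is the cheaper choice, which is why training frameworks are built around it.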


1 Answer

Not having to implement backward() is one of the main reasons PyTorch, or any other DL framework, is so valuable. In fact, you should implement backward() only in very specific cases: when you need to modify the network's gradient, or when you create a custom Function that can't be expressed using PyTorch's built-in functions.
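For reference, a real custom Function subclasses torch.autograd.Function with static forward(ctx, ...) and backward(ctx, grad_output) methods and uses ctx.save_for_backward. The sketch below only mimics that shape in plain Python with hypothetical Ctx, Exp and RoundSTE classes: Exp saves its output for reuse, and RoundSTE is an example of deliberately altering the gradient (a straight-through estimator).

```python
import math

class Ctx:
    """Toy context object, playing the role of ctx in a custom Function."""
    def save_for_backward(self, *values):
        self.saved = values

class Exp:
    """Toy custom Function: forward computes exp(x), backward reuses the result."""
    @staticmethod
    def forward(ctx, x):
        y = math.exp(x)
        ctx.save_for_backward(y)   # d/dx exp(x) = exp(x), so keep the output
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved
        return grad_output * y

class RoundSTE:
    """Toy straight-through Function: round in forward, identity in backward."""
    @staticmethod
    def forward(ctx, x):
        return float(round(x))

    @staticmethod
    def backward(ctx, grad_output):
        # The true derivative of round() is 0 almost everywhere; we override it.
        return grad_output

def run(fn, x, grad_output=1.0):
    """Hypothetical driver: one forward call, then one backward call."""
    ctx = Ctx()
    y = fn.forward(ctx, x)
    return y, fn.backward(ctx, grad_output)

print(run(Exp, 0.0))        # (1.0, 1.0)
print(run(RoundSTE, 2.4))   # (2.0, 1.0)
```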

PyTorch computes gradients using a computational graph that keeps track of the operations performed during your forward pass. Any operation done on a Variable (merged into Tensor since PyTorch 0.4) is implicitly registered in this graph. Computing the gradients is then a matter of traversing the graph backward, starting from the variable on which backward() was called, and applying the chain rule of differentiation.
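The traversal step can be sketched as a toy reverse-mode engine; Value is a hypothetical class, not PyTorch's engine. backward() first orders the recorded graph so each node follows its inputs, then walks that order in reverse from the output, multiplying by local derivatives and summing contributions when a value is used more than once.

```python
class Value:
    """Toy scalar node: stores its value, parents and local derivatives."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.parents = parents           # inputs this value was computed from
        self.local_grads = local_grads   # d(self)/d(parent) for each parent
        self.grad = 0.0

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # 1. Topologically order the recorded graph (inputs before outputs).
        order, seen = [], set()
        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v.parents:
                    visit(p)
                order.append(v)
        visit(self)
        # 2. Walk it in reverse, from this output back to the leaves,
        #    applying the chain rule and summing over every path.
        self.grad = 1.0
        for v in reversed(order):
            for p, local in zip(v.parents, v.local_grads):
                p.grad += v.grad * local

x, y = Value(3.0), Value(4.0)
z = x * y + x          # x is used twice
z.backward()
print(x.grad, y.grad)  # 5.0 3.0
```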

PyTorch's About page has a nice visualization of the graph and how it generally works. If you want more details, I'd also recommend reading up on computational graphs and the autograd mechanism.

EDIT: The source code where all this happens is in the C++ part of PyTorch's codebase, where the actual graph is implemented. After some digging around, I found this:

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
    profiler::RecordFunction rec(this);
    if (jit::tracer::isTracingVar(inputs)) {
        return traced_apply(inputs);
    }
    return apply(inputs);
}

So in each Function, PyTorch first checks whether its inputs need tracing and, if so, calls traced_apply(). There you can see the node being created and appended to the graph:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
    var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);
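A rough Python analogue of the dispatch in operator() and the node insertion in traced_apply() may make the C++ easier to follow. Everything here is a hypothetical toy: TRACE stands in for the tracer state, the None check stands in for isTracingVar(), and Graph.append_node stands in for graph->appendNode(). The real trace is a JIT IR graph, not a Python list.

```python
class Graph:
    """Toy trace: an ordered list of (op name, input values) records."""
    def __init__(self):
        self.nodes = []

    def append_node(self, name, inputs):
        self.nodes.append((name, inputs))

# Hypothetical global tracing state; None means "not tracing".
TRACE = {"graph": None}

class Function:
    """Toy base class echoing operator(): trace if active, then apply."""
    def __call__(self, *inputs):
        graph = TRACE["graph"]
        if graph is not None:               # stands in for isTracingVar(inputs)
            return self.traced_apply(graph, *inputs)
        return self.apply(*inputs)

    def traced_apply(self, graph, *inputs):
        # Create a node for this call, record its inputs, append it to the graph.
        graph.append_node(type(self).__name__, inputs)
        return self.apply(*inputs)

class Add(Function):
    def apply(self, a, b):
        return a + b

class Mul(Function):
    def apply(self, a, b):
        return a * b

print(Mul()(Add()(1, 2), 4))   # 12 (not tracing: nothing recorded)
TRACE["graph"] = Graph()
print(Mul()(Add()(1, 2), 4))   # 12
print(TRACE["graph"].nodes)    # [('Add', (1, 2)), ('Mul', (3, 4))]
```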

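My best guess is that every Function object registers itself and its inputs (if needed) when it executes. Method calls (e.g. variable.dot()) simply defer to the corresponding Function, so the same applies to them. Note, however, that the code above belongs to the JIT tracer (jit::tracer), which records a trace for ONNX export and torch.jit.trace; the backward graph itself is linked together by pointing each output's grad_fn at the Function node that produced it, and that node in turn keeps edges to the nodes of its inputs (visible in Python as grad_fn.next_functions).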
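That "registers itself upon execution" idea can be sketched in a few lines. Tensor and Dot here are hypothetical toy classes, not PyTorch's: the dot() method only defers to Dot.apply(), and running apply() is what hooks the node into the graph by setting the output's grad_fn. As a simplification, backward() only handles leaf inputs.

```python
class Tensor:
    """Toy tensor whose methods defer to Function objects."""
    def __init__(self, data, requires_grad=False):
        self.data = list(data)
        self.requires_grad = requires_grad
        self.grad_fn = None   # set only on outputs of recorded operations
        self.grad = None

    def dot(self, other):
        # The method knows nothing about gradients; it defers to a Function.
        return Dot.apply(self, other)

    def backward(self):
        # Start from the node that produced this (scalar) output.
        self.grad_fn.backward(1.0)

class Dot:
    """Toy Function: running apply() is what links it into the backward graph."""
    @classmethod
    def apply(cls, a, b):
        node = cls()
        node.inputs = (a, b)    # edges back to the inputs
        out = Tensor([sum(x * y for x, y in zip(a.data, b.data))])
        if a.requires_grad or b.requires_grad:
            out.requires_grad = True
            out.grad_fn = node  # the "registration": output points at its creator
        return out

    def backward(self, grad):
        a, b = self.inputs
        # d(a.b)/da = b and d(a.b)/db = a, scaled by the incoming gradient.
        # Toy limitation: a and b are assumed to be leaves.
        for t, other in ((a, b), (b, a)):
            if t.requires_grad:
                t.grad = [grad * v for v in other.data]

a = Tensor([1.0, 2.0], requires_grad=True)
b = Tensor([3.0, 4.0])
out = a.dot(b)
out.backward()
print(type(out.grad_fn).__name__, out.data)  # Dot [11.0]
print(a.grad, b.grad)                        # [3.0, 4.0] None
```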

NOTE: I don't take part in PyTorch's development and am in no way an expert on its architecture. Any corrections or additions are welcome.

answered Nov 01 '22 at 00:11 by Mach_Zero
