
In PyTorch, what exactly does the grad_fn attribute store and how is it used?

In PyTorch, the Tensor class has a grad_fn attribute. This references the operation used to obtain the tensor: for instance, if a = b + 2, a.grad_fn will be AddBackward0. But what does "reference" mean exactly?

Inspecting AddBackward0 with inspect.getmro(type(a.grad_fn)) shows that its only base class is object. Additionally, the source code for this class (and in fact for any other class that might be encountered in grad_fn) is nowhere to be found in the PyTorch source code!
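For reference, the observation can be reproduced with something like the snippet below. The tensor value 3.0 is an arbitrary choice for illustration, and the exact printed representations may vary between PyTorch versions.

```python
import inspect
import torch

b = torch.tensor(3.0, requires_grad=True)  # leaf tensor tracked by autograd
a = b + 2                                  # tensor produced by an addition

print(a.grad_fn)                           # the AddBackward0 object attached to a
print(inspect.getmro(type(a.grad_fn)))     # the method resolution order of its type
print(b.grad_fn)                           # None: a leaf was not produced by an operation
```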

All of this leads me to the following questions:

  1. What precisely is stored in grad_fn and how is it called during back-propagation?
  2. How come the objects that get stored in grad_fn do not have some sort of common super class, and why is there no source code for them on GitHub?
David Cian asked Nov 07 '22 02:11



1 Answer

grad_fn is a function "handle" that gives access to the applicable gradient function. During back-propagation, that function is called to compute the gradient at the given point, and the resulting gradient is the coefficient used to adjust the weights.
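As a rough illustration of how these handles are chained and then used, here is a small sketch. The names w, x, y and loss and the numeric values are made up for the example; next_functions is the attribute through which one grad_fn points at the gradient functions of its inputs.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # a "weight" we want a gradient for
x = torch.tensor(5.0)                      # plain input, no gradient needed
y = w * x                                  # y.grad_fn records the multiplication
loss = y + 1                               # loss.grad_fn records the addition

# Each grad_fn links back to the grad_fn objects of the tensors it came from.
print(loss.grad_fn)
print(loss.grad_fn.next_functions)

loss.backward()                            # walks the grad_fn chain from loss back to w
print(w.grad)                              # d(loss)/dw = x, i.e. tensor(5.)
```

Calling backward() is what actually invokes the stored gradient functions; until then they only sit on the tensors as a record of how each value was computed.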

"Handle" is a general term for an object descriptor, designed to give appropriate access to the object. For instance, when you open a file, open returns a file handle. When you instantiate a class, the call returns a handle to the created instance (the same object that __init__ receives as self). The handle contains references (usually memory addresses) to the data and functions for the item in question.
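To make the idea concrete, here is a toy pure-Python sketch of a handle that holds references to a gradient function and to the values it was computed from. Node, Var, add, mul and backprop are all hypothetical names invented for this illustration; this is not how PyTorch is implemented internally, only the same idea in miniature.

```python
class Node:
    """Hypothetical stand-in for a grad_fn object (not PyTorch's real class)."""
    def __init__(self, name, backward, parents):
        self.name = name
        self.backward = backward  # maps an upstream gradient to one gradient per parent
        self.parents = parents    # the Var objects this result was computed from


class Var:
    """Toy scalar value that remembers which operation produced it."""
    def __init__(self, value, grad_fn=None):
        self.value = value
        self.grad_fn = grad_fn    # None for leaves, a Node otherwise
        self.grad = 0.0


def add(x, c):
    # d(x + c)/dx = 1, so the upstream gradient passes through unchanged.
    return Var(x.value + c, Node("AddBackward", lambda g: (g,), (x,)))


def mul(x, y):
    # d(x * y)/dx = y and d(x * y)/dy = x.
    return Var(x.value * y.value,
               Node("MulBackward", lambda g: (g * y.value, g * x.value), (x, y)))


def backprop(out):
    # Follow every path from the output to the leaves, accumulating gradients.
    # Simple but not efficient: shared sub-results are revisited once per path.
    stack = [(out, 1.0)]
    while stack:
        var, g = stack.pop()
        if var.grad_fn is None:
            var.grad += g         # leaf: accumulate the gradient here
            continue
        grads = var.grad_fn.backward(g)
        for parent, parent_grad in zip(var.grad_fn.parents, grads):
            stack.append((parent, parent_grad))


w, x = Var(2.0), Var(5.0)
loss = add(mul(w, x), 1.0)
backprop(loss)
print(loss.grad_fn.name, w.grad, x.grad)  # AddBackward 5.0 2.0
```

The point of the sketch is only that grad_fn is an object carrying two things: a callable that knows the local derivative, and references to whatever sits upstream in the computation.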

It appears as a direct subclass of the generic object class because it comes from the underlying implementation in another language (C++), so it does not map exactly onto the Python function type. PyTorch handles the inter-language call and return. This hand-off is part of the pre-compiled (shared-object) run-time system, which is also why you will not find Python source for these classes.

Is that enough to clarify what you see?

Prune answered Nov 12 '22 17:11
