In PyTorch, the Tensor class has a grad_fn attribute. This references the operation used to obtain the tensor: for instance, if a = b + 2, a.grad_fn will be AddBackward0. But what does "reference" mean exactly?
Inspecting AddBackward0 using inspect.getmro(type(a.grad_fn)) shows that the only base class of AddBackward0 is object. Additionally, the source code for this class (and in fact, for any other class which might be encountered in grad_fn) is nowhere to be found in the source code!
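The observations above can be reproduced with a short script. This is a sketch that assumes a recent PyTorch install; the exact repr and memory address printed for grad_fn will vary between versions and runs. It also prints next_functions, the attribute through which each grad_fn node links to the nodes that receive gradients for its inputs.

```python
import inspect

import torch

b = torch.ones(3, requires_grad=True)
a = b + 2

print(a.grad_fn)                        # e.g. <AddBackward0 object at 0x...>
print(inspect.getmro(type(a.grad_fn)))  # the question observes only AddBackward0 and object

# grad_fn is a node in the autograd graph; next_functions points at the nodes
# that will receive the gradients for this operation's inputs during backward().
print(a.grad_fn.next_functions)         # first entry is the AccumulateGrad node for b

# backward() walks that graph starting from the output's grad_fn.
a.sum().backward()
print(b.grad)                           # tensor([1., 1., 1.])
```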
All of this leads me to the following questions:
- What is grad_fn, and how is it called during back-propagation?
- Why do the objects stored in grad_fn not have some sort of common super class, and why is there no source code for them on GitHub?

grad_fn is a function "handle", giving access to the applicable gradient function, i.e. the backward function of the operation that produced the tensor. The gradient at the given point is a coefficient used for adjusting weights during back-propagation.
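To make "how is it called during back-propagation" concrete, here is a toy, pure-Python model of an autograd graph. It is a sketch under loud assumptions: the class names (Value, Node, AddBackward0, MulBackward0, AccumulateGrad) and the next_functions attribute only mimic PyTorch's naming for illustration, and the shared Node base class is a toy convenience. None of this is PyTorch's actual implementation, which lives in compiled C++.

```python
# Toy model of an autograd graph. Class names mimic PyTorch's for
# illustration only; this is NOT PyTorch's actual implementation.


class Node:
    """Base for toy backward nodes (PyTorch's Python-visible nodes have no such base)."""
    next_functions = ()

    def apply(self, grad):
        # Map the gradient w.r.t. this node's output to gradients w.r.t. its inputs.
        raise NotImplementedError


class Value:
    def __init__(self, data, requires_grad=False):
        self.data = data
        self.requires_grad = requires_grad
        self.grad = None
        self.grad_fn = None  # set when the value is produced by an operation

    def __add__(self, other):
        other = _wrap(other)
        out = Value(self.data + other.data)
        out.grad_fn = AddBackward0(self, other)
        return out

    def __mul__(self, other):
        other = _wrap(other)
        out = Value(self.data * other.data)
        out.grad_fn = MulBackward0(self, other)
        return out


def _wrap(x):
    # Treat plain numbers as constants that do not require gradients.
    return x if isinstance(x, Value) else Value(x)


def _edge(v):
    # Node that should receive the gradient for input v (None for constants).
    if v.grad_fn is not None:
        return v.grad_fn
    if v.requires_grad:
        return AccumulateGrad(v)
    return None


class AccumulateGrad(Node):
    # Terminal node for a leaf: adds the incoming gradient into leaf.grad.
    def __init__(self, leaf):
        self.leaf = leaf

    def apply(self, grad):
        self.leaf.grad = grad if self.leaf.grad is None else self.leaf.grad + grad
        return ()


class AddBackward0(Node):
    # d(x + y)/dx = 1 and d(x + y)/dy = 1
    def __init__(self, x, y):
        self.next_functions = (_edge(x), _edge(y))

    def apply(self, grad):
        return (grad, grad)


class MulBackward0(Node):
    # d(x * y)/dx = y and d(x * y)/dy = x, so the input values are saved for backward
    def __init__(self, x, y):
        self.saved = (x.data, y.data)
        self.next_functions = (_edge(x), _edge(y))

    def apply(self, grad):
        x, y = self.saved
        return (grad * y, grad * x)


def backward(root, grad=1.0):
    # Start at the output's grad_fn, call each node with its incoming gradient,
    # and forward the results along next_functions until the leaves are reached.
    def visit(node, g):
        if node is None:
            return
        for next_node, next_grad in zip(node.next_functions, node.apply(g)):
            visit(next_node, next_grad)

    visit(root.grad_fn, grad)


b = Value(3.0, requires_grad=True)
a = b + 2  # a.grad_fn is an AddBackward0 node
c = a * b  # c = (b + 2) * b, so dc/db = 2b + 2
backward(c)
print(type(c.grad_fn).__name__)  # MulBackward0
print(b.grad)  # 8.0
```

The recursive walk visits a shared node once per path that reaches it. That still gives the right sums because each backward function is linear in the incoming gradient, but a real engine instead orders the nodes topologically and adds up all incoming gradients before calling each node once.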
"Handle" is a general term for an object descriptor, designed to give appropriate access to the object. For instance, when you open a file, open returns a file object, which is a handle to that file. When you instantiate a class, calling the class returns a handle to the created instance (__init__ itself only initializes that instance and returns None). The handle contains references (usually memory addresses) to the data and functions for the item in question.
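Both examples of a handle can be checked with the standard library alone. This is a minimal sketch; the Point class is hypothetical, and tempfile is used only so the snippet does not depend on any existing file.

```python
import tempfile


class Point:  # hypothetical example class
    def __init__(self, x, y):
        self.x = x
        self.y = y
        # __init__ only initializes the new object; it implicitly returns None


# Calling the class (not __init__) hands back a reference to the new instance.
p = Point(1, 2)
q = p  # copies the reference, not the object
q.x = 10
print(p.x)  # 10: both names refer to the same instance
print(Point.__init__(p, 3, 4))  # None

# open() (here via tempfile) returns a file object: a handle wrapping an OS file descriptor.
with tempfile.TemporaryFile(mode="w+") as f:
    f.write("hello")
    f.seek(0)
    contents = f.read()
    print(contents)  # hello
```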
It appears as the generic object class because it comes from the underlying implementation in another language (C++), so it does not map exactly onto a Python class hierarchy or the Python function type. PyTorch handles the inter-language call and return. This hand-off is part of the pre-compiled (shared-object) run-time system.
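Python's own built-ins, which are implemented in C, show the same symptoms, so they make a handy analogy: inspect reports only object as a base and cannot find any Python source. This sketch uses only the standard library; the helper has_python_source is hypothetical. As far as I know, PyTorch's concrete backward classes such as AddBackward0 are additionally generated at build time (from tools/autograd/derivatives.yaml), which is why searching the repository for the class itself turns up nothing.

```python
import inspect


def has_python_source(obj):
    # Hypothetical helper: True if inspect can locate Python source for obj.
    try:
        inspect.getsource(obj)
        return True
    except (TypeError, OSError):
        # TypeError: built-in (C-implemented) object; OSError: source file unavailable
        return False


print(has_python_source(len))     # False: len is implemented in C
print(has_python_source(dict))    # False: dict is a built-in class
print(inspect.getmro(type(len)))  # (<class 'builtin_function_or_method'>, <class 'object'>)
```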
Is that enough to clarify what you see?