Consider the following function:

import tensorflow as tf

def foo(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x**2 + x + 4
    return tape.gradient(y, x)
The call to tape.watch(x) is necessary if the function is called, say, as foo(tf.constant(3.14)), but is not needed when a variable is passed in directly, such as foo(tf.Variable(3.14)).
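For concreteness, both call patterns look like this (a minimal sketch assuming TensorFlow 2.x eager execution and the foo and import from above; the expected value of about 7.28 is just dy/dx = 2x + 1 at x = 3.14):

# Plain tensor: the gradient would be None without the explicit tape.watch(x).
print(foo(tf.constant(3.14)))   # ≈ 7.28

# Variable: watched automatically, and additionally watched inside foo.
print(foo(tf.Variable(3.14)))   # ≈ 7.28 as well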
Now my question is: is the call to tape.watch(x) safe even in the case when a tf.Variable is passed in directly? Or will some strangeness happen due to the variable already being auto-watched and then being watched manually again? What is the correct way to write general functions like this that can accept both tf.Tensor and tf.Variable?
It should be safe. On the one hand, the documentation of tf.GradientTape.watch says:

    Ensures that tensor is being traced by this tape.

"Ensures" seems to imply that it will make sure it is traced in case it is not. In fact, the documentation does not give any indication that using it twice on the same object should be a problem (although it wouldn't hurt if they made that explicit).
But in any case, we can dig into the source code to check. In the end, calling watch on a variable (the answer ends up the same if it's not a variable, but the path diverges slightly) comes down to the WatchVariable method of a GradientTape class in C++:
void WatchVariable(PyObject* v) {
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
  if (handle == nullptr) {
    return;
  }
  tensorflow::int64 id = FastTensorId(handle.get());
  if (!PyErr_Occurred()) {
    this->Watch(id);
  }
  tensorflow::mutex_lock l(watched_variables_mu_);
  auto insert_result = watched_variables_.emplace(id, v);
  if (insert_result.second) {
    // Only increment the reference count if we aren't already watching this
    // variable.
    Py_INCREF(v);
  }
}
The second half of the method shows that the watched variable is added to watched_variables_, which is a std::set, so adding the same variable again does nothing. This is actually checked later to make sure Python reference counting is correct. The first half basically calls Watch:
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}
tensor_tape_ is a map (specifically a tensorflow::gtl::FlatMap, pretty much the same as a standard C++ map), so if tensor_id is already there this will have no effect.
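This idempotent behaviour is also easy to confirm from Python (a minimal sketch, assuming TensorFlow 2.x eager execution; the value of about 7.28 is just 2 * 3.14 + 1 for the expression in the question):

import tensorflow as tf

v = tf.Variable(3.14)
with tf.GradientTape() as tape:
    tape.watch(v)  # v is auto-watched already; watch it explicitly anyway
    tape.watch(v)  # and a second time, for good measure
    y = v**2 + v + 4
print(tape.gradient(y, v))  # a single gradient of ≈ 7.28, not doubled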
So, even though it is not explicitly stated, everything suggests there should be no issues with it.