 

How to assign a new value to a pytorch Variable without breaking backpropagation?

Tags:

pytorch

I have a pytorch variable that is used as a trainable input for a model. At some point I need to manually reassign all values in this variable.

How can I do that without breaking the connections with the loss function?

Suppose the current values are [1.2, 3.2, 43.2] and I simply want them to become [1,2,3].


Edit

At the time I asked this question, I hadn't realized that PyTorch doesn't have a static graph the way TensorFlow or Keras do.

In PyTorch, the training loop is written manually and you call everything explicitly in each training step. (There is no notion of a placeholder plus a static graph that data is fed into later.)

Consequently, we can't "break the graph", since the new variable will be used to perform all further computations again. I was worried about a problem that happens in Keras, not in PyTorch.
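For completeness, here is a minimal sketch of what this looks like in practice (not part of the original question; it assumes a simple quadratic loss as a stand-in for a real model):

import torch

# A trainable input tensor, optimized in a manual training loop
x = torch.tensor([1.2, 3.2, 43.2], requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)

for step in range(3):
    optimizer.zero_grad()
    loss = (x ** 2).sum()   # the forward pass builds a fresh graph every step
    loss.backward()         # gradients flow to x through this step's graph
    optimizer.step()

# Reassign the values; this assignment is not tracked by autograd
x.data = torch.tensor([1.0, 2.0, 3.0])

# The next step simply uses the new values - nothing is "broken"
optimizer.zero_grad()
loss = (x ** 2).sum()
loss.backward()
print(x.grad)               # gradients w.r.t. the new values [1., 2., 3.]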

asked Dec 17 '18 by Daniel Möller


People also ask

How do you change the value of a tensor in PyTorch?

We can modify a tensor by using the assignment operator. Assigning a new value to an element or slice of the tensor modifies the tensor in place. Import the torch library and then create a PyTorch tensor.
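For example (a small assumed illustration, not from the original page):

import torch

t = torch.tensor([1.2, 3.2, 43.2])
t[0] = 1.0                          # assign a single element
t[1:] = torch.tensor([2.0, 3.0])    # assign a slice
print(t)                            # tensor([1., 2., 3.])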

What does requires_grad do in PyTorch?

Setting the requires_grad parameter allows for fine-grained exclusion of subgraphs from gradient computation. It takes effect in both the forward and backward passes: during the forward pass, an operation is only recorded in the backward graph if at least one of its input tensors requires grad.
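A small assumed illustration of this behaviour:

import torch

a = torch.rand(3, requires_grad=True)
b = torch.rand(3)                # requires_grad defaults to False

out = (a * b).sum()
out.backward()

print(a.grad)                    # gradients were computed for a
print(b.grad)                    # None - b is excluded from gradient computation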

What does * do in PyTorch?

For .view(), PyTorch expects the new shape to be provided as individual int arguments (represented in the docs as *shape). The asterisk (*) can be used in Python to unpack a list into its individual elements, thus passing view the form of arguments it expects.
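For example (assumed illustration):

import torch

t = torch.arange(6)
shape = [2, 3]

print(t.view(2, 3).shape)        # torch.Size([2, 3])
print(t.view(*shape).shape)      # same result: the list is unpacked into 2, 3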

What is grad_fn in PyTorch?

Each tensor has a grad_fn attribute that references the function that created it (except for tensors created by the user, which have None as their grad_fn).
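A small assumed illustration:

import torch

x = torch.rand(3, requires_grad=True)    # created by the user -> leaf tensor
y = x * 2
z = y.sum()

print(x.grad_fn)                         # None
print(y.grad_fn)                         # <MulBackward0 object at ...>
print(z.grad_fn)                         # <SumBackward0 object at ...>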


1 Answer

You can use the data attribute of tensors to modify the values, since modifications to data are not tracked by autograd.
So the graph will still be intact, and modifications of the data attribute itself have no influence on the graph. (Operations and changes on data are not tracked by autograd and thus are not present in the graph.)

Since you haven't given an example, this example is based on your comment:
'Suppose I want to change the weights of a layer.'
I used plain tensors here, but this works the same for the weight.data and bias.data attributes of a layer.

Here is a short example:

import torch
import torch.nn.functional as F



# Test 1, random vector with CE
w1 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w1, torch.tensor([1]))
loss.backward()
print('w1.data', w1.data)
print('w1.grad', w1.grad)
print()

# Test 2, replacing values of w2 with w1, before CE
# to make sure that everything is exactly like in Test 1 after replacing the values
w2 = torch.zeros(1, 3, requires_grad=True)
w2.data = w1.data
loss = F.cross_entropy(w2, torch.tensor([1]))
loss.backward()
print('w2.data', w2.data)
print('w2.grad', w2.grad)
print()

# Test 3, replace data after computation
w3 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w3, torch.tensor([1]))
# setting values
# the graph of the previous computation is still intact, as you can see in the print-outs below
w3.data = w1.data
loss.backward()

# data were replaced with values from w1
print('w3.data', w3.data)
# gradient still shows results from computation with w3
print('w3.grad', w3.grad)

Output:

w1.data tensor([[ 0.9367,  0.6669,  0.3106]])
w1.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w2.data tensor([[ 0.9367,  0.6669,  0.3106]])
w2.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w3.data tensor([[ 0.9367,  0.6669,  0.3106]])
w3.grad tensor([[ 0.3179, -0.7114,  0.3935]])

The most interesting part here is w3. At the time backward is called, its values have already been replaced by the values of w1.
But the gradients are calculated based on the CE function with the original values of w3. The replaced values have no effect on the graph, so the graph connection is not broken and the replacement had no influence on the graph. I hope this is what you were looking for!

answered Oct 12 '22 by MBT