How to get around an in-place operation error when indexing a leaf variable for a gradient update?

I am encountering an in-place operation error when I try to index a leaf variable in order to update its values with a customized shrink function after the gradient step. I cannot work around it. Any help is highly appreciated!

import torch.nn as nn
import torch
import numpy as np
from torch.autograd import Variable, Function

# hyperparameters
batch_size = 100 # batch size of images
ld = 0.2 # sparse penalty
lr = 0.1 # learning rate

x = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,10,10))), requires_grad=False)  # original

# depends on size of the dictionary, number of atoms.
D = Variable(torch.from_numpy(np.random.normal(0,1,(500,10,10))), requires_grad=True)

# ht sparse representation
ht = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,500,1,1))), requires_grad=True)

# Dictionary loss function
loss = nn.MSELoss()

# customized shrink function to update gradient
shrink_ht = lambda x: torch.stack([torch.sign(i)*torch.max(torch.abs(i)-lr*ld,0)[0] for i in x])

### sparse representation optimizer_ht, per single image.
optimizer_ht = torch.optim.SGD([ht], lr=lr, momentum=0.9) # optimizer for sparse representation

## update for the batch
for idx in range(len(x)):
    optimizer_ht.zero_grad() # clear up gradients
    loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)),p=2)**2
    loss_ht.backward() # backpropagation to calculate gradients
    optimizer_ht.step() # update parameters with gradients
    ht[idx] = shrink_ht(ht[idx])  # customized shrink function.

RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>()
     15     loss_ht.backward() # backpropagation to calculate gradients
     16     optimizer_ht.step() # update parameters with gradients
---> 17     ht[idx] = shrink_ht(ht[idx]) # customized shrink function.
     18
     19

/home/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py in __setitem__(self, key, value)
     85             return MaskedFill.apply(self, key, value, True)
     86         else:
---> 87             return SetItem.apply(self, key, value)
     88
     89     def __deepcopy__(self, memo):

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

Specifically, the line below seems to cause the error, since it indexes and updates the leaf variable at the same time:

ht[idx] = shrink_ht(ht[idx])  # customized shrink function.

Thanks.

W.S.



3 Answers

I just found the fix: in order to update the variable, it needs to be ht.data[idx] instead of ht[idx]. We can use .data to access the underlying tensor directly, so autograd does not record the in-place assignment.
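
For example, a minimal sketch of the update loop with that one change (same shrink_ht, loss, and optimizer as in the question; writing through .data bypasses autograd tracking):

for idx in range(len(x)):
    optimizer_ht.zero_grad()                 # clear up gradients
    loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)), p=2)**2
    loss_ht.backward()                       # backpropagation to calculate gradients
    optimizer_ht.step()                      # update parameters with gradients
    # write through .data so the in-place shrink is not recorded by autograd
    ht.data[idx] = shrink_ht(ht.data[idx])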


W.S.


The problem comes from the fact that ht requires grad:

ht = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,500,1,1))), requires_grad=True)

For variables that require gradients, PyTorch does not allow you to assign values to them (or to slices of them). You can't do:

ht[idx] = some_tensor

That means you will need to find another way to implement your customized shrink function, using built-in PyTorch functions such as squeeze, unsqueeze, etc.

Another option is to assign your shrink_ht(ht[idx]) slices to another variable or tensor that does not require gradients.
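
A minimal sketch of that second option, assuming the shrink_ht from the question (ht_shrunk is a hypothetical buffer that does not require grad, so no in-place write ever touches the leaf variable):

# hypothetical buffer for the shrunk codes; requires_grad is False
ht_shrunk = torch.zeros_like(ht.data)
for idx in range(len(x)):
    # detach the slice so the shrink is computed outside the autograd graph
    ht_shrunk[idx] = shrink_ht(ht[idx].detach())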


patapouf_ai


Using ht.data[idx] is OK here, but the newer convention is to explicitly use torch.no_grad(), for example:

with torch.no_grad(): 
    ht[idx] = shrink_ht(ht[idx])

Note that no gradient is computed for this in-place operation. In other words, gradients flow back only to the shrunk values of ht, not to the unshrunk values of ht.
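
For context, a sketch of how this fits into the loop from the question (same variables and shrink_ht; the shrink step is simply wrapped in torch.no_grad() after the optimizer step):

for idx in range(len(x)):
    optimizer_ht.zero_grad()
    loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)), p=2)**2
    loss_ht.backward()
    optimizer_ht.step()
    with torch.no_grad():            # autograd is not recording, so the in-place write is allowed
        ht[idx] = shrink_ht(ht[idx])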


THN