I am trying to update/change the parameters of a neural net model and then have the forward pass of the updated net stay in the computation graph, no matter how many changes/updates we do.
Whenever I try this, PyTorch sets the updated tensors inside the model to be leaves, which kills the flow of gradients to the networks I want to receive gradients. It kills the flow of gradients because leaf nodes are not part of the computation graph the way I want them to be (since they aren't truly leaves).
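To illustrate with a minimal sketch (this is my understanding of what goes wrong): copying updated values into the model with load_state_dict hands back leaf parameters with no grad_fn, so the graph back to the tensors that produced them is cut:

import torch
import torch.nn as nn

net = nn.Linear(1, 1)
w_new = net.weight * 2                # non-leaf, still connected to net.weight
print(w_new.is_leaf, w_new.grad_fn)   # False, a MulBackward0 object

net.load_state_dict({'weight': w_new, 'bias': net.bias})
print(net.weight.is_leaf)             # True  -> a leaf again
print(net.weight.grad_fn)             # None  -> the graph through w_new is gone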
I've tried multiple things but nothing seems to work. I wrote a self-contained dummy script that prints the gradients of the networks I want to have gradients:
import torch
import torch.nn as nn
import copy
from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in loss_net.named_parameters():
        print(f'name = {name}')
        print(w.size())
        hidden = updater_net(hidden).view(1)
        print(hidden.size())
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        print(wt.size())
        new_params[name] = wt
        #del loss_net.fc0.weight
        #setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
        #setattr(loss_net.fc0, 'weight', wt)
        #loss_net.fc0.weight = wt
        #loss_net.fc0.weight = nn.Parameter( wt )
    ##
    loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
If anyone knows how to do this, please give me a ping... I set the number of updates to 2 because the update operation has to stay in the computation graph an arbitrary number of times, so it MUST at least work for 2.
One caveat: this doesn't work perfectly, because the replaced parameters get deleted from the module's parameter registry (they become plain tensors, so named_parameters() stops listing them).
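To make that caveat concrete, here is a small sketch (my own toy example) of what happens when a Parameter attribute is replaced by a plain tensor:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
print([name for name, _ in lin.named_parameters()])  # ['weight', 'bias']

wt = lin.weight * 2         # a non-leaf "updated" weight
delattr(lin, 'weight')      # removes 'weight' from lin._parameters
setattr(lin, 'weight', wt)  # re-attach it as a plain (non-Parameter) tensor

print([name for name, _ in lin.named_parameters()])  # ['bias'] -> 'weight' is gone
out = lin(torch.randn(1))   # the forward pass still works and stays in the graph
print(out.grad_fn is not None)  # True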
Seems this works:
import torch
import torch.nn as nn
from torchviz import make_dot
import copy
from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
def del_attr(obj, names):
    # recursively delete a (possibly dotted) attribute, e.g. del_attr(net, ['fc0', 'weight'])
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # recursively set a (possibly dotted) attribute, e.g. set_attr(net, ['fc0', 'weight'], wt)
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    # note: after the first outer iteration this list is empty, because the
    # parameters were removed from the module (the caveat mentioned above)
    for (name, w) in list(loss_net.named_parameters()):
        hidden = updater_net(hidden).view(1)
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        del_attr(loss_net, name.split("."))
        set_attr(loss_net, name.split("."), wt)  # wt is now a plain non-leaf tensor
    ##
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
print(f'loss_net.fc0.weight.is_leaf = {loss_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')  # None because hidden is not a leaf; it is overridden in the loop above
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
make_dot(loss_val)
output:
updater_net.fc0.weight.is_leaf = True
i = 0
i = 1
updater_net.fc0.weight.is_leaf = True
loss_net.fc0.weight.is_leaf = False
-- params that dont matter if they have gradients --
loss_net.grad = None
-- params we want to have gradients --
hidden.grad = None
updater_net.fc0.weight.grad = tensor([[0.7152]])
updater_net.fc0.bias.grad = tensor([-7.4249])
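As a side note (my addition, not from the linked thread): on newer PyTorch versions, torch.func.functional_call (PyTorch 2.x; torch.nn.utils.stateless.functional_call in 1.12+) lets you run the forward pass with an arbitrary dict of non-leaf parameters without mutating the module at all. A rough sketch of the same toy problem, assuming a recent PyTorch install:

import torch
import torch.nn as nn
from collections import OrderedDict
from torch.func import functional_call  # PyTorch 2.x

x = torch.randn(1)
target = 12.0*x**2

loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1, out_features=1))]))
updater_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1, out_features=1))]))
hidden = torch.randn(size=(1, 1), requires_grad=True)

# keep the "current" weights in a plain dict instead of mutating the module
params = {name: w for name, w in loss_net.named_parameters()}

nb_updates = 2
for i in range(nb_updates):
    new_params = {}
    for name, w in params.items():
        hidden = updater_net(hidden).view(1)
        delta = w + hidden            # same toy update rule as above
        new_params[name] = w + delta  # non-leaf tensor, stays in the graph
    params = new_params

# forward pass of loss_net, but with the updated (non-leaf) parameters
outputs = functional_call(loss_net, params, (x,))
loss_val = 0.5*(target - outputs)**2
loss_val.backward()

print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')

Since the module's parameters are never removed, both update iterations actually run, and the gradients flow back through the whole chain of updates to updater_net.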
Acknowledgement: thanks to the mighty albanD from the PyTorch team: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076/9?u=pinocchio