
What does the copy_initial_weights documentation mean in the higher library for Pytorch?

I was trying to use the higher library for meta-learning and I was having issues understanding what copy_initial_weights means. The docs say:

copy_initial_weights – if true, the weights of the patched module are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If this is set to False, the actual module weights will be the initial weights of the patched module. This is useful when doing MAML, for example.

but that doesn't make much sense to me because of the following:

For example, "the weights of the patched module are copied to form the initial weights of the patched module" doesn't make sense to me because when the context manager is initiated a patched module does not exist yet. So it is unclear what we are copying from and to where (and why copying is something we want to do).

Also, "unrolling the patched module" does not make sense to me. We usually unroll a computaiton graph caused by a for loop. A patched module is just a neural net that has been modified by this library. Unrolling is ambiguous.

Also, there isn't a technical definition for "gradient tape".

Also, the description of False says it's useful for MAML, but that isn't actually helpful because it doesn't even hint at why it's useful for MAML.

Overall, it's impossible to use the context manager.

Any explanations and examples of what that flag does in more precise terms would be really valuable.


Related:

  • GitHub issue: https://github.com/facebookresearch/higher/issues/30
  • newer GitHub issue: https://github.com/facebookresearch/higher/issues/54
  • pytorch forum: https://discuss.pytorch.org/t/why-does-maml-need-copy-initial-weights-false/70387
  • pytorch forum: https://discuss.pytorch.org/t/what-does-copy-initial-weights-do-in-the-higher-library/70384
  • a related question on how the fmodel parameters are copied so that the optimizers work (and on the use of deep copy): Why does higher need to deep copy the parameters of the base model to create a functional model?
asked Feb 20 '20 by Charlie Parker


2 Answers

Short version

A call to higher.innerloop_ctx with model as an argument creates a temporary patched model and an unrolled optimizer for that model: (fmodel, diffopt). In the inner loop, fmodel is expected to iteratively receive some input, compute the output and loss, after which diffopt.step(loss) is called. Each time diffopt.step is called, fmodel creates the next version of its parameters, fmodel.parameters(time=T), which is a new tensor computed from the previous ones (with the full graph preserved, allowing gradients to be computed through the process). If at any point the user calls backward on any tensor, regular pytorch gradient computation/accumulation starts, in a way that allows gradients to propagate to e.g. the optimizer's parameters (such as lr and momentum, if they were passed as tensors requiring gradients to higher.innerloop_ctx using override).

The creation-time version of fmodel's parameters, fmodel.parameters(time=0), is a copy of the original model's parameters. If copy_initial_weights=True is provided (the default), then fmodel.parameters(time=0) will be a clone+detach'ed version of model's parameters (i.e. it will preserve the values but sever all connections to the original model). If copy_initial_weights=False is provided, then fmodel.parameters(time=0) will be a clone'd version of model's parameters, and thus will allow gradients to propagate to the original model's parameters (see the pytorch doc on clone).
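
To make this concrete, here is a minimal sketch in plain pytorch (not higher's code) of the difference between clone and clone+detach for gradient flow:

import torch

p = torch.tensor(1.0, requires_grad=True)  # stands in for an original model parameter

a = p.clone()                                # copy_initial_weights=False analogue
b = p.clone().detach().requires_grad_(True)  # copy_initial_weights=True analogue

(a * 2).backward()
print(p.grad)  # tensor(2.) - the clone kept the link to p

p.grad = None
(b * 2).backward()
print(p.grad)  # None - detach severed the link; only b.grad was populated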

Terminology clarifications

  • gradient tape here refers to the graph pytorch uses to go through computations in order to propagate gradients to all leaf tensors requiring gradients. If at some point you cut the link to some leaf tensor requiring gradients (e.g. as is done to fnet.parameters() in the copy_initial_weights=True case), then the original model.parameters() won't be "on the gradient tape" anymore for your meta_loss.backward() computation (see the sketch after this list).

  • unrolling the patched module here refers to the part of the meta_loss.backward() computation where pytorch goes through all the fnet.parameters(time=T), starting from the latest and ending with the earliest (higher doesn't control this process - it is just regular pytorch gradient computation; higher is only in charge of how these new time=T parameters are created from the previous ones each time diffopt.step is called, and of making fnet always use the latest ones for the forward computation).
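
Both terms can be seen in a few lines of plain pytorch (a toy sketch, not higher's internals): each "step" builds a new parameter tensor from the previous one, and backward then walks that chain in reverse:

import torch

w = torch.tensor(2.0, requires_grad=True)  # "original model parameter"
w0 = w.clone()  # stays on the tape; w.clone().detach() would cut the link instead

# "unrolling": each inner step creates a new tensor from the previous one
w1 = w0 - 0.1 * torch.autograd.grad((w0 * 3 - 1) ** 2, w0, create_graph=True)[0]
w2 = w1 - 0.1 * torch.autograd.grad((w1 * 3 - 1) ** 2, w1, create_graph=True)[0]

meta_loss = (w2 - 5.0) ** 2
meta_loss.backward()  # walks w2 -> w1 -> w0 -> w, i.e. "unrolls" the chain
print(w.grad)         # non-None only because w0 kept its link to w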

Long version

Let's start from the beginning. The main functionality (the only functionality, really) of the higher library is unrolling a model's parameter optimization in a differentiable manner. It can come either in the form of directly using a differentiable optimizer through e.g. higher.get_diff_optim as in this example, or in the form of higher.innerloop_ctx as in this example.

The higher.innerloop_ctx option wraps the creation of a "stateless" model fmodel from your existing model and gives you an "optimizer" diffopt for this fmodel. So, as summarized in the README.md of higher, it allows you to switch from:

model = MyModel()
opt = torch.optim.Adam(model.parameters())

for xs, ys in data:
    opt.zero_grad()
    logits = model(xs)
    loss = loss_function(logits, ys)
    loss.backward()
    opt.step()

to

model = MyModel()
opt = torch.optim.Adam(model.parameters())

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for xs, ys in data:
        logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
        loss = loss_function(logits, ys)  # no need to call loss.backwards()
        diffopt.step(loss)  # note that `step` must take `loss` as an argument!

    # At the end of your inner loop you can obtain these e.g. ...
    grad_of_grads = torch.autograd.grad(
        meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))

The difference between training the model in the ordinary way and calling diffopt.step to update fmodel is that fmodel does not update its parameters in-place, as opt.step() in the original version would. Instead, each time diffopt.step is called, new versions of the parameters are created in such a way that fmodel uses the new ones for the next step, while all the previous ones are still preserved.

I.e. fmodel starts with only fmodel.parameters(time=0) available, but after you have called diffopt.step N times you can ask fmodel to give you fmodel.parameters(time=i) for any i up to N inclusive. Notice that fmodel.parameters(time=0) doesn't change in this process at all; every time fmodel is applied to some input, it simply uses the latest version of the parameters it currently has.
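
For example, here is a small self-contained snippet (a sketch, assuming higher is installed) showing that all the versions remain accessible:

import torch
import torch.nn as nn
import higher

model = nn.Linear(1, 1, bias=False)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 1), torch.randn(8, 1)

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for _ in range(3):
        diffopt.step(((fmodel(x) - y) ** 2).mean())
    for t in range(4):  # time=0 .. time=3 now all exist
        print(t, [p.detach().item() for p in fmodel.parameters(time=t)])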

Now, what exactly is fmodel.parameters(time=0)? It is created here and depends on copy_initial_weights. If copy_initial_weights==True then fmodel.parameters(time=0) are clone'd and detach'ed parameters of model. Otherwise they are only clone'd, but not detach'ed!

That means that when we do the meta-optimization step, the original model's parameters will actually accumulate gradients if and only if copy_initial_weights==False. And in MAML we want to optimize the model's starting weights, so we actually do need to get gradients from the meta-optimization step.

I think one of the issues here is that higher lacks simpler toy examples demonstrating what is going on, instead rushing to show more serious things in its examples. So let me try to fill that gap here and demonstrate what is going on using the simplest toy example I could come up with: a model with one weight which multiplies its input by that weight:

import torch
import torch.nn as nn
import torch.optim as optim
import higher
import numpy as np

np.random.seed(1)
torch.manual_seed(3)
N = 100
actual_multiplier = 3.5
meta_lr = 0.00001
loops = 5 # how many iterations in the inner loop we want to do

x = torch.tensor(np.random.random((N,1)), dtype=torch.float64) # features for inner training loop
y = x * actual_multiplier # target for inner training loop
model = nn.Linear(1, 1, bias=False).double() # simplest possible model - multiply input x by weight w without bias
meta_opt = optim.SGD(model.parameters(), lr=meta_lr, momentum=0.)


def run_inner_loop_once(model, verbose, copy_initial_weights):
    lr_tensor = torch.tensor([0.3], requires_grad=True)
    momentum_tensor = torch.tensor([0.5], requires_grad=True)
    opt = optim.SGD(model.parameters(), lr=0.3, momentum=0.5)
    with higher.innerloop_ctx(model, opt, copy_initial_weights=copy_initial_weights, override={'lr': lr_tensor, 'momentum': momentum_tensor}) as (fmodel, diffopt):
        for j in range(loops):
            if verbose:
                print('Starting inner loop step j=={0}'.format(j))
                print('    Representation of fmodel.parameters(time={0}): {1}'.format(j, str(list(fmodel.parameters(time=j)))))
                print('    Notice that fmodel.parameters() is same as fmodel.parameters(time={0}): {1}'.format(j, (list(fmodel.parameters())[0] is list(fmodel.parameters(time=j))[0])))
            out = fmodel(x)
            if verbose:
                print('    Notice how `out` is `x` multiplied by the latest version of weight: {0:.4} * {1:.4} == {2:.4}'.format(x[0,0].item(), list(fmodel.parameters())[0].item(), out[0].item()))
            loss = ((out - y)**2).mean()
            diffopt.step(loss)

        if verbose:
            # after all inner training let's see all steps' parameter tensors
            print()
            print("Let's print all intermediate parameters versions after inner loop is done:")
            for j in range(loops+1):
                print('    For j=={0} parameter is: {1}'.format(j, str(list(fmodel.parameters(time=j)))))
            print()

        # let's imagine now that our meta-learning optimization is trying to check how far we got in the end from the actual_multiplier
        weight_learned_after_full_inner_loop = list(fmodel.parameters())[0]
        meta_loss = (weight_learned_after_full_inner_loop - actual_multiplier)**2
        print('  Final meta-loss: {0}'.format(meta_loss.item()))
        meta_loss.backward() # will only propagate gradient to original model parameter's `grad` if copy_initial_weights=False
        if verbose:
            print('  Gradient of final loss we got for lr and momentum: {0} and {1}'.format(lr_tensor.grad, momentum_tensor.grad))
            print('  If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smaller')
        return meta_loss.item()

print('=================== Run Inner Loop First Time (copy_initial_weights=True) =================\n')
meta_loss_val1 = run_inner_loop_once(model, verbose=True, copy_initial_weights=True)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))

print('=================== Run Inner Loop Second Time (copy_initial_weights=False) =================\n')
meta_loss_val2 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))

print('=================== Run Inner Loop Third Time (copy_initial_weights=False) =================\n')
final_meta_gradient = list(model.parameters())[0].grad.item()
# Now let's double-check `higher` library is actually doing what it promised to do, not just giving us
# a bunch of hand-wavy statements and difficult to read code.
# We will do a simple SGD step using meta_opt changing initial weight for the training and see how meta loss changed
meta_opt.step()
meta_opt.zero_grad()
meta_step = - meta_lr * final_meta_gradient # how much meta_opt actually shifted initial weight value
meta_loss_val3 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)

meta_loss_gradient_approximation = (meta_loss_val3 - meta_loss_val2) / meta_step

print()
print('Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: {0:.4} VS {1:.4}'.format(meta_loss_gradient_approximation, final_meta_gradient))

Which produces this output:

=================== Run Inner Loop First Time (copy_initial_weights=True) =================

Starting inner loop step j==0
    Representation of fmodel.parameters(time=0): [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]
    Notice that fmodel.parameters() is same as fmodel.parameters(time=0): True
    Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.9915 == -0.4135
Starting inner loop step j==1
    Representation of fmodel.parameters(time=1): [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    Notice that fmodel.parameters() is same as fmodel.parameters(time=1): True
    Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.1217 == -0.05075
Starting inner loop step j==2
    Representation of fmodel.parameters(time=2): [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    Notice that fmodel.parameters() is same as fmodel.parameters(time=2): True
    Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 1.015 == 0.4231
Starting inner loop step j==3
    Representation of fmodel.parameters(time=3): [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    Notice that fmodel.parameters() is same as fmodel.parameters(time=3): True
    Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.064 == 0.8607
Starting inner loop step j==4
    Representation of fmodel.parameters(time=4): [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    Notice that fmodel.parameters() is same as fmodel.parameters(time=4): True
    Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.867 == 1.196

Let's print all intermediate parameters versions after inner loop is done:
    For j==0 parameter is: [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]
    For j==1 parameter is: [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    For j==2 parameter is: [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    For j==3 parameter is: [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    For j==4 parameter is: [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]
    For j==5 parameter is: [tensor([[3.3908]], dtype=torch.float64, grad_fn=<AddBackward0>)]

  Final meta-loss: 0.011927987982895929
  Gradient of final loss we got for lr and momentum: tensor([-1.6295]) and tensor([-0.9496])
  If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smaller

Let's see if we got any gradient for initial model parameters: None

=================== Run Inner Loop Second Time (copy_initial_weights=False) =================

  Final meta-loss: 0.011927987982895929

Let's see if we got any gradient for initial model parameters: tensor([[-0.0053]], dtype=torch.float64)

=================== Run Inner Loop Third Time (copy_initial_weights=False) =================

  Final meta-loss: 0.01192798770078706

Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: -0.005311 VS -0.005311
answered Oct 20 '22 by Alexander Pivovarov


I think it's more or less clear to me now what this means.

First I'd like to make some notation clear, especially with respect to the indices for the inner time step and the outer time step (also known as episodes):

W^<inner_i, outer_i> = the value a tensor has at inner time step inner_i and outer time step outer_i.

At the beginning of training a neural net has params:

W^<0,0>

and these are held inside its module. For the sake of explanation, the specific tensor (for the base model) will be denoted:

W = the tensor holding the weights of the model. This can be thought of as the initialization of the model.

and it will be updated with an in-place operation by the outer optimizer (this is important, since W is the placeholder for all W^<0,outer_i> for all outer step values during "normal" meta-learning). I want to emphasize that W is the tensor for the normal pytorch neural net base model. By changing this in-place with an outer optimizer (like Adam) we are effectively training the initialization. The outer optimizer will use the gradients with respect to this tensor to do the update through the whole unrolled inner loop process.
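
Putting that together, here is a hedged sketch of the outer loop being described (toy model and data, illustrative names):

import torch
import torch.nn as nn
import higher

model = nn.Linear(1, 1, bias=False)  # holds W, the initialization being meta-trained
meta_opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for outer_i in range(10):  # episodes
    x, y = torch.randn(8, 1), torch.randn(8, 1)  # toy task data
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    meta_opt.zero_grad()
    with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        for _ in range(5):  # inner adaptation starting from W^<0, outer_i>
            diffopt.step(((fmodel(x) - y) ** 2).mean())
        meta_loss = ((fmodel(x) - y) ** 2).mean()
        meta_loss.backward()  # reaches W because copy_initial_weights=False
    meta_opt.step()  # in-place update of W, producing W^<0, outer_i+1>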

When we say copy_initial_weights=False we mean that we will have a gradient path directly to W with whatever value it currently has. Usually the context manager is entered before the inner loop, after an outer step has been done, so W will hold W^<0,outer_i> for the current step. In particular, the code that does this for copy_initial_weights=False is this one:

params = [ p.clone() if device is None else p.clone().to(device) for p in module.parameters() ]

This might look confusing if you're not familiar with clone, but what it's doing is making a copy of the current value of W. The unusual thing is that clone also remembers the gradient history of the tensor it came from (.clone() acts as an identity as far as gradients are concerned). Its main use is to add an extra layer of safety against the user doing dangerous in-place ops in their differentiable optimizer. Assuming the user never did anything crazy with in-place ops, one could in theory remove the .clone(). The reason this is confusing, imho, is that "copying" in pytorch (i.e. cloning) does not automatically block gradient flows, which is what a "real" copy would do (i.e. create a 100% totally separate tensor). That is not what clone does, and that is not what copy_initial_weights does.
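
A toy illustration of both points (plain pytorch, not higher's code): the clone snapshots W's value, yet still backprops to it:

import torch

p = torch.tensor(1.0, requires_grad=True)  # stands in for W
c = p.clone()  # snapshot of p's value that still backprops to p

with torch.no_grad():
    p.mul_(100.0)  # an in-place edit of the original...

print(c)       # tensor(1., grad_fn=<CloneBackward0>) - ...doesn't touch the clone
c.backward()
print(p.grad)  # tensor(1.) - the gradient still flows back to p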

When copy_initial_weights=True what really happens is that the weights are cloned and detached. See the code it eventually runs (here and here):

params = [_copy_tensor(p, safe_copy, device) for p in module.parameters()]

which runs _copy_tensor (assuming a safe copy is being done, i.e. the extra clone):

t = t.clone().detach().requires_grad_(t.requires_grad)

Note that .detach() does not allocate new memory; it shares memory with the original tensor, which is why the .clone() is needed to make this op "safe" (usually with respect to in-place ops).
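
This is easy to check directly (a small sketch; data_ptr just exposes the storage address):

import torch

t = torch.randn(3, requires_grad=True)

d = t.detach()
print(d.data_ptr() == t.data_ptr())  # True: detach alone shares the storage

s = t.clone().detach().requires_grad_(t.requires_grad)
print(s.data_ptr() == t.data_ptr())  # False: the clone allocated fresh memory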

So with copy_initial_weights=True, higher copies and detaches the current value of W, which is usually W^<0,outer_i> when doing usual meta-learning in the inner adaptation loop. So the intended semantics of copy_initial_weights is exactly that, and by the initial weights they simply mean W. The important thing to note is that the intermediate tensors for the net in the inner loop are not denoted in my notation; they are fmodel.parameters(time=inner_i). Also, in usual meta-learning, fmodel.parameters(time=0) is (a clone of) W, and W gets updated in-place by the outer optimizer.

Note that because of the outer optimizer's in-place op and the freeing of the graphs, we never take the derivative Grad_{W^<0,0>}, i.e. the derivative with respect to the initial value of W, which was something I initially thought we were doing.

answered Oct 20 '22 by Charlie Parker