
Why is there a different output between model.forward(input) and model(input)?

I'm using PyTorch to build a simple model like VGG16, and I have overridden the forward function in my model.

I found that everyone tends to use model(input) to get the output rather than model.forward(input), and I am interested in the difference between them. I tried feeding in the same data, but the results are different, and I'm confused.

I printed the layer weights before feeding in the data, and the weights did not change. I also know that model(input) uses the __call__ function, which in turn calls model.forward.

    vgg = VGG()
    vgg.double()

    for layer in vgg.modules():
        if isinstance(layer, torch.nn.Linear):
            print(layer.weight)

    print("   use model.forward(input)   ")
    result = vgg.forward(array)

    for layer in vgg.modules():
        if isinstance(layer, torch.nn.Linear):
            print(layer.weight)

    print("   use model(input)   ")
    result_2 = vgg(array)

    print(result)
    print(result_2)

output:

    Variable containing:
    1.00000e-02 *
    -0.2931  0.6716 -0.3497 -2.0217 -0.0764  1.2162  1.4983 -1.2881
    [torch.DoubleTensor of size 1x8]

    Variable containing:
    1.00000e-02 *
    0.5302  0.4494 -0.6866 -2.1657 -0.9504  1.0211  0.8308 -1.1665
    [torch.DoubleTensor of size 1x8]
asked Mar 25 '19 by KaguyaSan


1 Answer

model.forward just runs the forward pass, as you mention, but __call__ does a little extra.

If you dig into the code of the nn.Module class, you will see that __call__ ultimately calls forward, but internally it also handles forward and backward hooks and manages some state that PyTorch allows. When calling a simple model, like a plain MLP, that may not really be needed, but more complex models, such as layers with spectral normalization, rely on hooks, and therefore you should use the model(.) signature as much as possible, unless you explicitly want to call model.forward.
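To make the hook point concrete, here is a simplified sketch (not the real nn.Module source; class and method names are illustrative) of how __call__ wraps forward with hook handling, and why bypassing __call__ skips the hooks:

```python
# Simplified, hypothetical sketch of the __call__-wraps-forward pattern.
class Module:
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)  # a hook may observe or replace the output
            if result is not None:
                out = result
        return out

class Double(Module):
    def forward(self, x):
        return 2 * x

m = Double()
m.register_forward_hook(lambda mod, inp, out: out + 1)
print(m.forward(3))  # 6 -- bypasses the hook
print(m(3))          # 7 -- __call__ runs the hook
```

A module whose behavior depends on hooks (as with spectral normalization, which re-computes the weight in a pre-forward hook) will therefore give different results when you call forward directly.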

Also see Calling forward function without .forward()

In this case, however, the difference is most likely due to a dropout layer: you should call vgg.eval() to make sure all the stochasticity in the network is turned off before comparing the outputs.
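A minimal sketch of this, using a hypothetical small model with dropout rather than the asker's VGG: in training mode every call draws a fresh dropout mask, so repeated calls on the same input generally disagree, while eval() makes the forward pass deterministic.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.Dropout(p=0.5),  # source of the run-to-run difference
    torch.nn.Linear(8, 1),
)
x = torch.randn(1, 8)

model.train()   # modules are in training mode by default
a = model(x)
b = model(x)    # a fresh dropout mask is drawn, so a and b differ in general

model.eval()    # dropout becomes the identity; batchnorm uses running stats
c = model(x)
d = model(x)
assert torch.equal(c, d)  # eval-mode outputs are deterministic
```

The same applies to the original snippet: calling vgg.eval() before both invocations makes model.forward(input) and model(input) agree for a hook-free model.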

answered Sep 28 '22 by Umang Gupta