
PyTorch weak_script_method decorator

Tags:

python

pytorch

I came across some code in an introduction to Word2Vec and PyTorch that I'm not quite familiar with; I haven't seen this type of code structure before.

>>> import torch
>>> from torch import nn

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)

tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

I'm a little confused about the following line of code.

>>> embedding(input)

I may have inadvertently overlooked this syntax in the past, but I don't recall seeing a variable passed to a class instance before. Referring to the PyTorch documentation, where the Embedding class is defined: is this behaviour enabled by the @weak_script_method decorator wrapping def forward()? The code below suggests this may be the case:

>>> torch.manual_seed(2)
>>> torch.eq(embedding(input), embedding.forward(input)).all()

tensor(1, dtype=torch.uint8)

Why is use of the @weak_script_method decorator preferable in this case?

Josmoor98 asked Mar 24 '26 16:03


1 Answer

No, @weak_script_method has nothing to do with it. embedding(input) follows the standard Python function call syntax, which can be used both with "traditional" functions and with objects that define the __call__(self, *args, **kwargs) magic method. So this code

class Greeter:
    def __init__(self, name):
        self.name = name

    def __call__(self, name):
        print('Hello to ' + name + ' from ' + self.name + '!')

greeter = Greeter('Jatentaki')
greeter('EBB')

will result in Hello to EBB from Jatentaki! being printed to stdout. Similarly, Embedding is an object which you construct by telling it how many embeddings it should contain, what their dimensionality should be, etc., and then, after it is constructed, you can call it like a function to retrieve the desired part of the embedding.
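To make the "construct, then call" pattern concrete, here is a toy analogue of an embedding lookup (hypothetical illustration only, not how nn.Embedding is actually implemented; real embeddings hold learned random weights, whereas this one uses a deterministic fake table):

```python
class ToyEmbedding:
    """A hypothetical stand-in for nn.Embedding: a table of rows,
    looked up by calling the instance like a function."""

    def __init__(self, num_embeddings, dim):
        # deterministic fake "weights" so the example is reproducible;
        # nn.Embedding would initialise these randomly and learn them
        self.weight = [[i * dim + j for j in range(dim)]
                       for i in range(num_embeddings)]

    def __call__(self, indices):
        # the call syntax triggers __call__, which performs the lookup
        return [self.weight[i] for i in indices]

emb = ToyEmbedding(10, 3)
print(emb([1, 2, 4]))  # row lookups via the call syntax
```

The point is only the syntax: emb([1, 2, 4]) works because ToyEmbedding defines __call__, just as embedding(input) works in the question's snippet.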

The reason you do not see __call__ in nn.Embedding source is that it subclasses nn.Module, which provides an automatic __call__ implementation which delegates to forward and calls some extra stuff before and afterwards (see the documentation). So, calling module_instance(arguments) is roughly equivalent to calling module_instance.forward(arguments).
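A heavily simplified sketch of that delegation might look like the following (hypothetical; PyTorch's real nn.Module.__call__ also runs registered hooks and other bookkeeping around the forward call):

```python
class Module:
    """Hypothetical, stripped-down analogue of nn.Module."""

    def __call__(self, *args, **kwargs):
        # ... pre-forward hooks would run here in real PyTorch ...
        result = self.forward(*args, **kwargs)
        # ... post-forward hooks would run here in real PyTorch ...
        return result

    def forward(self, *args, **kwargs):
        raise NotImplementedError

class Doubler(Module):
    # subclasses only override forward, never __call__
    def forward(self, x):
        return 2 * x

doubler = Doubler()
print(doubler(21))          # calling the instance delegates to forward
print(doubler.forward(21))  # same result, minus the hook machinery
```

This is why torch.eq(embedding(input), embedding.forward(input)).all() in the question returns true: the call syntax ends up in the same forward method.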

The @weak_script_method decorator has little to do with it. It is related to JIT compatibility: @weak_script_method is a variant of @script_method designed for internal use in PyTorch. The only takeaway for you is that nn.Embedding is compatible with the JIT, should you want to use it.

Jatentaki answered Mar 27 '26 06:03


