I came across some code in an introduction to Word2Vec and PyTorch that I'm not quite familiar with. I haven't seen this type of code structure before.
>>> import torch
>>> from torch import nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],
        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])
I'm a little confused about the following line of code.
>>> embedding(input)
I may have overlooked this syntax in the past, but I don't recall ever seeing an argument passed to a class instance as if it were a function. Looking at the PyTorch documentation where class Embedding is defined, is this behaviour enabled by the @weak_script_method decorator wrapping def forward()? The code below suggests that this may be the case:
>>> torch.manual_seed(2)
>>> torch.eq(embedding(input), embedding.forward(input)).all()
tensor(1, dtype=torch.uint8)
Why is the @weak_script_method decorator preferable in this case?
No, @weak_script_method has nothing to do with it. embedding(input) is ordinary Python function-call syntax, which works both with "traditional" functions and with objects that define the __call__(self, *args, **kwargs) magic method. So this code
class Greeter:
    def __init__(self, name):
        self.name = name

    def __call__(self, name):
        print('Hello to ' + name + ' from ' + self.name + '!')

greeter = Greeter('Jatentaki')
greeter('EBB')
will print Hello to EBB from Jatentaki! to stdout. Similarly, Embedding is an object that you construct by telling it how many embeddings it should contain, what their dimensionality should be, and so on; once it is constructed, you can call it like a function to look up the embedding vectors for the indices you pass in.
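To make "look up the embedding vectors" concrete: the result of the call is simply the rows of the module's weight matrix selected by the indices. A minimal sketch, assuming a reasonably recent PyTorch release (the seed and the name `indices` are illustrative choices, not from the original code):

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)  # only makes the random weights reproducible

embedding = nn.Embedding(10, 3)  # weight is a (10, 3) matrix, one row per index
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long)

out = embedding(indices)
print(out.shape)  # torch.Size([2, 4, 3]): each index became a 3-vector

# The call selects rows of the weight matrix by index ...
print(torch.equal(out, embedding.weight[indices]))  # True
# ... which is also what the functional form of the lookup computes.
print(torch.equal(out, F.embedding(indices, embedding.weight)))  # True
```

This is also why index 4 (and likewise index 2) produces the same vector in both samples of the output shown in the question.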
The reason you do not see __call__ in the nn.Embedding source is that it subclasses nn.Module, which provides a generic __call__ implementation that delegates to forward and does some extra work before and afterwards, such as running any registered hooks (see the documentation). So calling module_instance(arguments) is roughly equivalent to calling module_instance.forward(arguments), which is why the two calls in your comparison give identical results.
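The delegation pattern itself is plain Python. The sketch below is a hypothetical, heavily simplified imitation: `MiniModule`, `add_pre_hook`, `add_post_hook` and `Doubler` are made-up names, not PyTorch API (PyTorch's own hook methods are `register_forward_pre_hook` and `register_forward_hook`, and the real `nn.Module.__call__` does considerably more). It shows why calling the instance and calling `forward` directly return the same value, while only the former runs the hooks:

```python
class MiniModule:
    """Hypothetical, simplified imitation of nn.Module's call routing (not PyTorch code)."""

    def __init__(self):
        self._pre_hooks = []   # each called as hook(module, args) before forward
        self._post_hooks = []  # each called as hook(module, args, output) after forward

    def add_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def add_post_hook(self, hook):
        self._post_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError  # subclasses put the actual computation here

    def __call__(self, *args):
        # instance(...) lands here: run hooks around the subclass's forward.
        for hook in self._pre_hooks:
            hook(self, args)
        output = self.forward(*args)
        for hook in self._post_hooks:
            hook(self, args, output)
        return output


class Doubler(MiniModule):
    # Like nn.Embedding, this subclass defines forward but no __call__.
    def forward(self, x):
        return 2 * x


log = []
m = Doubler()
m.add_post_hook(lambda mod, args, out: log.append(out))

print(m(21))          # 42: goes through __call__, so the hook runs
print(m.forward(21))  # 42: forward called directly, hook is skipped
print(log)            # [42]
```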
The @weak_script_method decorator is not what makes this work. It concerns JIT (TorchScript) compatibility: @weak_script_method is a variant of @script_method intended for internal use in PyTorch, and the only takeaway for you is that nn.Embedding is compatible with the JIT, should you want to use it.
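If you do want to use the JIT, a minimal sketch, assuming a recent PyTorch release in which `torch.jit.script` can compile an `nn.Module` instance directly (the name `indices` is illustrative):

```python
import torch
from torch import nn

embedding = nn.Embedding(10, 3)
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long)

# Compile the module with TorchScript; nn.Embedding's forward only uses
# operations that TorchScript supports, so this succeeds.
scripted = torch.jit.script(embedding)

# The compiled module holds the same weight values, so the lookups agree.
print(torch.equal(scripted(indices), embedding(indices)))  # True
```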