
How to invert a PyTorch Embedding?

I have a multi-task encoder/decoder model in PyTorch with a (trainable) torch.nn.Embedding layer at the input.

In one particular task, I'd like to pre-train the model self-supervised (to re-construct masked input data) and use it for inference (to fill in gaps in data).

I guess at training time I can just measure the loss as the distance between the input embedding and the output embedding. But at inference time, how do I invert an Embedding to recover the category/token that the output corresponds to? I can't see e.g. a "nearest" function on the Embedding class...

dingus asked Oct 25 '20 12:10



1 Answer

You can do it quite easily:

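import torch

# 1000 tokens, each with a 100-dimensional embedding
embeddings = torch.nn.Embedding(1000, 100)
# Stand-in for a single model output vector, shape (1, 100)
my_sample = torch.randn(1, 100)
# Euclidean distance from the sample to every embedding row, shape (1000,)
distance = torch.norm(embeddings.weight.detach() - my_sample, dim=1)
# Index of the closest token
nearest = torch.argmin(distance)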

Assuming you have 1000 tokens with 100-dimensional embeddings, this returns the index of the nearest embedding by Euclidean distance. You could use other metrics, such as cosine similarity, in the same way.
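
For a whole batch of output vectors, the same lookup might be sketched as below. This is only an illustration under assumptions: the helper names nearest_euclidean and nearest_cosine are hypothetical (they are not part of torch.nn.Embedding), the sizes are the same made-up 1000 x 100, and the "model output" is simulated by adding a little noise to real embeddings.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: 1000 tokens, 100-dimensional embeddings.
embeddings = torch.nn.Embedding(1000, 100)
weight = embeddings.weight.detach()  # (1000, 100), no gradient tracking


def nearest_euclidean(outputs: torch.Tensor) -> torch.Tensor:
    """Index of the closest embedding row (L2 distance) for each row of `outputs`."""
    distances = torch.cdist(outputs, weight)  # (batch, 1000)
    return distances.argmin(dim=1)


def nearest_cosine(outputs: torch.Tensor) -> torch.Tensor:
    """Index of the most cosine-similar embedding row for each row of `outputs`."""
    similarity = F.normalize(outputs, dim=1) @ F.normalize(weight, dim=1).T
    return similarity.argmax(dim=1)


# Simulated model output: embed known token ids and perturb them slightly.
token_ids = torch.tensor([3, 42, 999])
noisy = embeddings(token_ids).detach() + 0.01 * torch.randn(3, 100)

recovered_l2 = nearest_euclidean(noisy)
recovered_cos = nearest_cosine(noisy)
```

With noise this small, both lookups should return the original token ids. Which metric suits a real model depends on how its outputs were trained, so treat the choice here as a starting point rather than a recommendation.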

Szymon Maszke answered Oct 23 '22 04:10
