
How to batch convert sentence lengths to masks in PyTorch?

Tags: nlp, pytorch

For example, from

lens = [3, 5, 4]

we want to get

mask = [[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0]]

Both lens and mask are torch.LongTensors.

0472037 asked Nov 20 '18 23:11




2 Answers

One way that I found is:

torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

Please share if there are better ways!
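The one-liner above leaves max_len undefined, and in recent PyTorch versions the comparison returns a boolean tensor rather than a LongTensor. Here is a minimal sketch that fills in both gaps. It assumes max_len should default to the longest length in the batch, and lengths_to_mask is a hypothetical helper name, not part of PyTorch:

```python
import torch

def lengths_to_mask(lens, max_len=None):
    """Hypothetical helper (not a PyTorch API): lengths -> 0/1 mask."""
    if max_len is None:
        # Assumption: pad to the longest sequence in the batch.
        max_len = int(lens.max())
    # Positions 0 .. max_len-1, created on the same device as lens.
    positions = torch.arange(max_len, device=lens.device)
    # Repeat the positions once per sequence, then compare with each length.
    mask = positions.expand(len(lens), max_len) < lens.unsqueeze(1)
    # Convert the comparison result to a LongTensor, as the question asks.
    return mask.long()

lens = torch.tensor([3, 5, 4])
mask = lengths_to_mask(lens)
print(mask)
```

If a boolean mask is what you need (for example for masked_fill), drop the final .long() call.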

0472037 answered Oct 20 '22 05:10


Just to add a bit of explanation of @ypc's answer (I cannot comment due to lack of reputation):

torch.arange(max_len)[None, :] < lens[:, None]

In short, this answer relies on broadcasting to perform implicitly the expansion that the accepted answer writes out with expand(). Step by step, taking max_len = 5 and lens = [3, 5, 4]:

  1. torch.arange(max_len) gives you [0, 1, 2, 3, 4], a tensor of shape (5,);

  2. indexing with [None, :] inserts a new dimension at position 0, making its shape (1, 5), which gives you [[0, 1, 2, 3, 4]];

  3. similarly, lens[:, None] inserts a new dimension at position 1 of lens, making its shape (3, 1), that is [[3], [5], [4]];

  4. comparing (or applying any elementwise operation such as +, -, *, /) a tensor of shape (1, 5) with one of shape (3, 1) follows the broadcasting rules, so the result has shape (3, 5), with result[i, j] = (j < lens[i]).
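The steps above can be checked by printing the intermediate shapes. This is only a sketch, and it assumes max_len is taken from the longest length in lens (5 in this example); the variable names row and col are illustrative:

```python
import torch

lens = torch.tensor([3, 5, 4])
max_len = int(lens.max())          # assumption: longest length, here 5

positions = torch.arange(max_len)  # step 1: shape (5,)
row = positions[None, :]           # step 2: new dim at position 0 -> (1, 5)
col = lens[:, None]                # step 3: new dim at position 1 -> (3, 1)
mask = row < col                   # step 4: broadcast to (3, 5)

print(positions.shape, row.shape, col.shape, mask.shape)
print(mask)
```

The result matches the explicit expand() version from the accepted answer element for element; broadcasting just avoids writing the expansion by hand.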

Nicholas Zhi answered Oct 20 '22 04:10