PyTorch: Concatenate rows in alternating order

Tags:

pytorch

I am trying to code up the positional encoding from the Transformer paper ("Attention Is All You Need"). In order to do so, I need to do an operation similar to the following:

a = torch.arange(20).reshape(4, 5)
b = a * 2
# Stack each (a_row, b_row) pair into a (2, 5) block, then concatenate -> (8, 5)
c = torch.cat([torch.stack([a_row, b_row]) for a_row, b_row in zip(a, b)])

I feel like there might be a faster way to do the above, perhaps by adding a dimension onto a and b?
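
For concreteness, running the snippet above on the given a and b produces the row-interleaved result:

print(c)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 0,  2,  4,  6,  8],
#         [ 5,  6,  7,  8,  9],
#         [10, 12, 14, 16, 18],
#         [10, 11, 12, 13, 14],
#         [20, 22, 24, 26, 28],
#         [15, 16, 17, 18, 19],
#         [30, 32, 34, 36, 38]])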

asked by sachinruk

2 Answers

I would simply use strided slice assignment for this:

c = torch.zeros(8, 5)
c[::2, :] = a   # Index every second row, starting from 0
c[1::2, :] = b  # Index every second row, starting from 1 

When timing the two solutions, I used the following:

import timeit
import torch
a = torch.arange(20).reshape(4,5) 
b = a * 2

suggested = timeit.timeit("c = torch.cat([torch.stack([a_row, b_row]) for a_row, b_row in zip(a, b)])",
                          setup="import torch; from __main__ import a, b", number=10000)
print(suggested/10000)
# 4.5105120493099096e-05

improved = timeit.timeit("c = torch.zeros(8, 5); c[::2, :] = a; c[1::2, :] = b", 
                         setup="import torch; from __main__ import a, b", number=10000)
print(improved/10000)
# 2.1489459509029985e-05

The second approach consistently takes about half the time, even though a single iteration is very fast either way. Of course, you would have to test this for your actual tensor sizes, but it is the most straightforward solution I could come up with. Can't wait to see if anyone has a nifty low-level solution that is even faster!

Also, keep in mind that I did not time the creation of b, assuming that the tensors you want to interleave are already given.
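
One caveat with this approach: torch.zeros defaults to float32, while torch.arange produces int64, so the assignments above silently cast a and b to float. A minimal sketch that preserves the input dtype (assuming a and b share shape and dtype):

import torch

a = torch.arange(20).reshape(4, 5)
b = a * 2

# Allocate the output with a's dtype so integer inputs stay integers
c = torch.empty(2 * a.shape[0], a.shape[1], dtype=a.dtype)
c[::2, :] = a   # rows 0, 2, 4, ... come from a
c[1::2, :] = b  # rows 1, 3, 5, ... come from b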

answered by dennlinger

So it turns out that simple concatenation and reshaping does the trick:

c = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])  # (4, 10) -> (8, 5); rows alternate a, b
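
For what it's worth, the "adding a dimension" idea from the question expresses the same interleave explicitly; a minimal sketch, assuming the same a and b:

import torch

a = torch.arange(20).reshape(4, 5)
b = a * 2

# Stack along a new dim=1 -> (4, 2, 5), then merge the first two dims into (8, 5)
c = torch.stack([a, b], dim=1).reshape(-1, a.shape[-1])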

When I timed the cat/view version with the following, it was about 2.3x faster than @dennlinger's answer:

improved2 = timeit.timeit("c = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])",  
                          setup="import torch; from __main__ import a, b", 
                          number=10000) 
print(improved2/10000) 
# 7.253780400003507e-06
print(improved / improved2)
# 2.3988091506044955
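
Tying this back to the original goal: in the Transformer's sinusoidal positional encoding, sine and cosine values alternate along the feature dimension, so the same strided-assignment trick from the first answer applies there. A minimal sketch (max_len and d_model are illustrative values, not from the question):

import math
import torch

max_len, d_model = 50, 16  # illustrative sizes

position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * (-math.log(10000.0) / d_model))               # (d_model // 2,)

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices get sin
pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices get cos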
answered by sachinruk