
How exactly does torch / np einsum work internally

This is a question about the internal workings of torch.einsum on the GPU. I know how to use einsum. Does it perform all possible multiplications and just pick out the relevant ones, or does it perform only the required computation?

For example, consider two tensors a and b, both of shape (N, P), where I wish to compute the dot product of each corresponding pair of rows a[i] and b[i], each of shape (P,). Using einsum, the code is:

torch.einsum('ij,ij->i',a,b)

Without using einsum, another way to obtain the output is:

torch.diag(a @ b.t())

Now, the second version should perform significantly more computation than the first: einsum needs roughly N·P multiply-adds, while a @ b.t() computes all N×N pairwise dot products (N·N·P multiply-adds), so for N = 2000 it does about 2000 times more work. However, when I time the two operations, they take roughly the same amount of time to complete, which raises the question: does einsum compute all combinations (like the second version) and then pick out the relevant values?
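For reference, a minimal sketch on small tensors showing that 'ij,ij->i' gives the same row-wise dot products as an elementwise multiply followed by a row sum, neither of which ever forms the N×N matrix:

import torch

N, P = 4, 3
a = torch.rand(N, P)
b = torch.rand(N, P)

# 'ij,ij->i': multiply matching elements, then sum over the repeated index j.
ein = torch.einsum('ij,ij->i', a, b)

# Equivalent elementwise formulation: no (N, N) intermediate is created.
ref = (a * b).sum(dim=1)

# The full-matmul formulation computes all N*N dot products, then keeps N of them.
full = torch.diag(a @ b.t())

print(torch.allclose(ein, ref), torch.allclose(ein, full))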

Sample Code to test:

import time
import torch

for i in range(100):
    a = torch.rand(50000, 256).cuda()
    b = torch.rand(50000, 256).cuda()

    # Time the diag(matmul) version.
    t1 = time.time()
    val = torch.diag(a @ b.t())
    t2 = time.time()

    # Time the einsum version.
    val2 = torch.einsum('ij,ij->i', a, b)
    t3 = time.time()

    print(t2 - t1, t3 - t2, torch.allclose(val, val2))
Asked by OlorinIstari

1 Answer

It probably has to do with the fact that the GPU can parallelize the computation of a @ b.t(): it does not have to wait for one row-column product to finish before starting the next, so the extra work is largely hidden by parallelism. Note also that CUDA kernels are launched asynchronously, so timing with time.time() alone, without calling torch.cuda.synchronize() first, largely measures kernel launch overhead rather than the computation itself. If you run the same comparison on the CPU, you will see that torch.diag(a @ b.t()) is significantly slower than torch.einsum('ij,ij->i', a, b) for large a and b.
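To illustrate, here is a minimal sketch of a fairer GPU timing comparison that synchronizes before and after each measurement. A smaller N than in the question is assumed, since diag(a @ b.t()) would otherwise materialize an N×N float32 intermediate (about 10 GB for N = 50000):

import time
import torch

# Smaller N than in the question: diag(a @ b.t()) materializes an (N, N)
# intermediate, which for N = 50000 would need ~10 GB of float32 memory.
N, P = 8000, 256
a = torch.rand(N, P).cuda()
b = torch.rand(N, P).cuda()

def timed(fn):
    torch.cuda.synchronize()  # make sure previously queued kernels finished
    t0 = time.time()
    out = fn()
    torch.cuda.synchronize()  # wait for this kernel to finish as well
    return out, time.time() - t0

val, t_diag = timed(lambda: torch.diag(a @ b.t()))
val2, t_einsum = timed(lambda: torch.einsum('ij,ij->i', a, b))

print(f"diag(matmul): {t_diag:.4f}s  einsum: {t_einsum:.4f}s  "
      f"allclose: {torch.allclose(val, val2)}")

With the synchronization in place, the reported times reflect actual kernel execution rather than launch overhead.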

Answered by Gil Pinsky