
Understanding PyTorch einsum

I'm familiar with how einsum works in NumPy. PyTorch offers similar functionality through torch.einsum(). What are the similarities and differences between the two, in terms of either functionality or performance? The PyTorch documentation on this is rather scanty and doesn't provide any insight.

kmario23 asked Apr 28 '19 21:04



1 Answer

Since the description of einsum is skimpy in the torch documentation, I decided to write this post to document how torch.einsum() behaves and to compare and contrast it with numpy.einsum().

Differences:

  • NumPy allows both lowercase and uppercase letters [a-zA-Z] in the "subscript string", whereas PyTorch, at the time of writing, allowed only lowercase letters [a-z]. More recent PyTorch releases also accept uppercase letters.

  • NumPy accepts nd-arrays, plain Python lists (or tuples), lists of lists (or tuples of tuples, lists of tuples, tuples of lists) and even PyTorch tensors as operands (i.e. inputs). This is because the operands only have to be array_like, not strictly NumPy nd-arrays. In contrast, PyTorch expects the operands (i.e. inputs) to be strictly PyTorch tensors. It will throw a TypeError if you pass plain Python lists/tuples (or combinations of them) or NumPy nd-arrays.

  • NumPy supports several keyword arguments (e.g. optimize) in addition to the operands, while PyTorch doesn't offer such flexibility yet.
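As a small illustration of the operand and keyword differences listed above, here is a minimal sketch. The names rows and col are made up for this example, and converting with torch.as_tensor is just one possible way to adapt array-like input for PyTorch.

```python
import numpy as np
import torch

rows = [[1, 2], [3, 4]]   # plain Python list of lists (illustrative data)
col = (10, 20)            # plain Python tuple (illustrative data)

# NumPy converts array_like operands on the fly.
np_result = np.einsum("ij, j -> i", rows, col)

# NumPy-only keyword argument: let einsum choose a contraction order.
np_optimized = np.einsum("ij, j -> i", rows, col, optimize=True)

# PyTorch wants tensors, so convert the operands explicitly first.
torch_result = torch.einsum(
    "ij, j -> i", torch.as_tensor(rows), torch.as_tensor(col)
)

print(np_result.tolist())     # [50, 110]
print(torch_result.tolist())  # [50, 110]
```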

Here are the implementations of some examples both in PyTorch and NumPy:

# input tensors to work with
In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]:
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
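The session above prints the input tensors but does not show how they were created. One possible way to build them, assuming the default integer dtype (int64), is sketched here so the examples can be reproduced:

```python
import torch

# 1-D vector: tensor([0, 1, 2, 3])
vec = torch.arange(4)

# 4x4 matrix whose entry (i, j) reads as "row digit, column digit"
aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])

# 4x4 matrix whose row i is filled with the value i + 1
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])

print(vec.shape, aten.shape, bten.shape)
```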

1) Matrix multiplication
PyTorch: torch.matmul(aten, bten) ; aten.mm(bten)
NumPy : np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]:
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2) Extract elements along the main-diagonal
PyTorch: torch.diag(aten)
NumPy : np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Hadamard product (i.e. element-wise product of two tensors)
PyTorch: aten * bten
NumPy : np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]:
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Element-wise squaring
PyTorch: aten ** 2
NumPy : np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]:
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

General: The element-wise nth power can be implemented by repeating the subscript string and the tensor n times. For example, the element-wise 4th power of a tensor can be computed using:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]:
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])
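The "repeat the subscript string and the tensor n times" rule can be wrapped in a small function. This is only a sketch: einsum_power is a hypothetical helper name, it assumes a 2-D input, and in practice aten ** n is the simpler choice.

```python
import torch

def einsum_power(tensor, n):
    """Element-wise n-th power of a 2-D tensor via einsum (hypothetical helper)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Build e.g. 'ij, ij, ij -> ij' for n = 3.
    subscripts = ", ".join(["ij"] * n) + " -> ij"
    # Pass the same tensor n times as the operands.
    return torch.einsum(subscripts, *[tensor] * n)

aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])

fourth = einsum_power(aten, 4)
print(torch.equal(fourth, aten ** 4))  # True
```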

5) Trace (i.e. sum of main-diagonal elements)
PyTorch: torch.trace(aten)
NumPy einsum: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)

6) Matrix transpose
PyTorch: torch.transpose(aten, 1, 0)
NumPy einsum: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]:
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7) Outer Product (of vectors)
PyTorch: torch.ger(vec, vec) (named torch.outer in newer releases)
NumPy einsum: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]:
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

8) Inner Product (of vectors)
PyTorch: torch.dot(vec1, vec2)
NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)

9) Sum along axis 0
PyTorch: torch.sum(aten, 0)
NumPy einsum: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Sum along axis 1
PyTorch: torch.sum(aten, 1)
NumPy einsum: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])
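Examples 1 to 10 each pair an einsum expression with a native PyTorch call. The following sketch checks those pairs against each other in one go. The dictionary layout is my own choice for this illustration, and the outer product is written with broadcasting to avoid depending on a particular function name.

```python
import torch

vec = torch.arange(4)
aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])

# name -> (einsum result, result of the native PyTorch operation)
checks = {
    "matmul":    (torch.einsum("ij, jk -> ik", aten, bten), torch.matmul(aten, bten)),
    "diagonal":  (torch.einsum("ii -> i", aten), torch.diag(aten)),
    "hadamard":  (torch.einsum("ij, ij -> ij", aten, bten), aten * bten),
    "square":    (torch.einsum("ij, ij -> ij", aten, aten), aten ** 2),
    "trace":     (torch.einsum("ii -> ", aten), torch.trace(aten)),
    "transpose": (torch.einsum("ij -> ji", aten), torch.transpose(aten, 1, 0)),
    "outer":     (torch.einsum("i, j -> ij", vec, vec), vec[:, None] * vec[None, :]),
    "inner":     (torch.einsum("i, i -> ", vec, vec), torch.dot(vec, vec)),
    "sum_axis0": (torch.einsum("ij -> j", aten), torch.sum(aten, 0)),
    "sum_axis1": (torch.einsum("ij -> i", aten), torch.sum(aten, 1)),
}

# True where the einsum form and the native form agree exactly.
results = {name: torch.equal(e, n) for name, (e, n) in checks.items()}
for name, ok in results.items():
    print(name, ok)
```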

11) Batch Matrix Multiplication
PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
NumPy : np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)
Out[15]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape

12) Sum along axis 2
PyTorch: torch.sum(batch_ten, 2)
NumPy einsum: np.einsum("ijk -> ij", arr3D)

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]:
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13) Sum all the elements in an nD tensor
PyTorch: torch.sum(batch_ten)
NumPy einsum: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)
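The definition of batch_ten is not shown in the session. Judging from the printed sums, it looks like aten and bten stacked along a new leading axis; the sketch below makes that assumption explicit and reproduces examples 12 and 13 under it.

```python
import torch

aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])

# Assumption: batch_ten stacks aten and bten, giving shape (2, 4, 4).
batch_ten = torch.stack((aten, bten))

# Sum along axis 2: drop the last subscript from the output.
row_sums = torch.einsum("ijk -> ij", batch_ten)

# Sum of all elements: empty output subscripts.
total = torch.einsum("ijk -> ", batch_ten)

print(row_sums.tolist())  # [[50, 90, 130, 170], [4, 8, 12, 16]]
print(total.item())       # 480
```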

14) Sum over multiple axes (i.e. marginalization)
PyTorch: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
NumPy: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# keep axis 5 (i.e. "n" here) and sum out all the other axes
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# the same with torch.sum: sum over every axis except axis 5
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True
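For tensors of other ranks, the subscript string can be generated instead of typed by hand. This is a rough sketch: keep_axis is a hypothetical helper name, it supports at most 26 dimensions because it draws subscripts from the lowercase alphabet, and float64 with a fixed seed is used here only to keep the comparison stable.

```python
import string
import torch

def keep_axis(tensor, axis):
    """Sum over every axis except `axis` using einsum (hypothetical helper)."""
    # One lowercase letter per dimension, e.g. 'abcdefgh' for an 8-D tensor.
    letters = string.ascii_lowercase[: tensor.dim()]
    # Only the kept axis appears on the output side, e.g. 'abcdefgh -> f'.
    return torch.einsum(f"{letters} -> {letters[axis]}", tensor)

torch.manual_seed(0)
nDten = torch.randn((3, 5, 4, 6, 8, 2, 7, 9), dtype=torch.float64)

esum = keep_axis(nDten, 5)

# Reference result: torch.sum over all axes except axis 5.
other_axes = tuple(d for d in range(nDten.dim()) if d != 5)
tsum = torch.sum(nDten, dim=other_axes)

print(esum.shape)
print(torch.allclose(tsum, esum))
```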

15) Double Dot Product / Frobenius inner product (same as torch.sum of the Hadamard product, cf. 3)
PyTorch: torch.sum(aten * bten)
NumPy : np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
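As a cross-check of the Frobenius inner product, the same number can be obtained in two other ways: summing the Hadamard product, and taking the trace of the transposed first matrix times the second. The trace form is an addition of mine based on the standard identity, not something from the session above.

```python
import torch

aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])

# einsum form: multiply matching entries, then sum over both indices.
via_einsum = torch.einsum("ij, ij -> ", aten, bten)

# Sum of the Hadamard (element-wise) product.
via_sum = torch.sum(aten * bten)

# Trace identity: sum_ij A_ij * B_ij == trace(A^T B).
via_trace = torch.trace(aten.t() @ bten)

print(via_einsum.item(), via_sum.item(), via_trace.item())  # 1300 1300 1300
```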
kmario23 answered Oct 15 '22 02:10