
How does the numpy.tensordot function work step by step?

Tags: python, numpy

I am new to numpy, so I am having trouble visualizing how the numpy.tensordot() function works. According to the tensordot documentation, the axes are passed as an argument, where axes=0 or 1 represents a normal matrix multiplication whereas axes=2 represents contraction.

Can somebody please explain how the multiplication would proceed for the given examples?

Example-1: a=[1,1] b=[2,2] for axes=0,1. Why does it throw an error for axes=2?
Example-2: a=[[1,1],[1,1]] b=[[2,2],[2,2]] for axes=0,1,2

john mich asked Dec 17 '22 21:12



2 Answers

Edit: The initial focus of this answer was on the case where axes is a tuple, specifying one or more axes for each argument. This form lets us perform variations on the conventional dot, especially for arrays larger than 2d (see also my answer in the linked question, https://stackoverflow.com/a/41870980/901925). A scalar axes is a special case that gets translated into the tuple version. So at its core it is still a dot product.

axes as tuple

In [235]: a=[1,1]; b=[2,2]

a and b are lists; tensordot turns them into arrays.

In [236]: np.tensordot(a,b,(0,0))
Out[236]: array(4)

Since they are both 1d arrays, we specify the axis values as 0.

If we try to specify 1:

In [237]: np.tensordot(a,b,(0,1))
---------------------------------------------------------------------------
   1282     else:
   1283         for k in range(na):
-> 1284             if as_[axes_a[k]] != bs[axes_b[k]]:
   1285                 equal = False
   1286                 break

IndexError: tuple index out of range

It is checking whether the size of axis 0 of a matches the size of axis 1 of b. But since b is 1d, it has no axis 1 to check.

In [239]: np.array(a).shape[0]
Out[239]: 2
In [240]: np.array(b).shape[1]
IndexError: tuple index out of range

Your second example is 2d arrays:

In [242]: a=np.array([[1,1],[1,1]]); b=np.array([[2,2],[2,2]])

Specifying the last axis of a and the first axis of b (its second to the last) produces the conventional matrix (dot) product:

In [243]: np.tensordot(a,b,(1,0))
Out[243]: 
array([[4, 4],
       [4, 4]])
In [244]: a.dot(b)
Out[244]: 
array([[4, 4],
       [4, 4]])

Better diagnostic values:

In [250]: a=np.array([[1,2],[3,4]]); b=np.array([[2,3],[2,1]])
In [251]: np.tensordot(a,b,(1,0))
Out[251]: 
array([[ 6,  5],
       [14, 13]])
In [252]: np.dot(a,b)
Out[252]: 
array([[ 6,  5],
       [14, 13]])

In [253]: np.tensordot(a,b,(0,1))
Out[253]: 
array([[11,  5],
       [16,  8]])
In [254]: np.dot(b,a)      # same numbers, different layout
Out[254]: 
array([[11, 16],
       [ 5,  8]])
In [255]: np.dot(b,a).T
Out[255]: 
array([[11,  5],
       [16,  8]])

Another pairing:

In [256]: np.tensordot(a,b,(0,0))
In [257]: np.dot(a.T,b)
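
The (0,0) pairing above can be checked with a short script. This is a minimal sketch using the same diagnostic arrays; the names r00 and same_as_dot are only illustrative.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3], [2, 1]])

# Pair axis 0 of a with axis 0 of b:
# r00[j, k] = sum over i of a[i, j] * b[i, k]
r00 = np.tensordot(a, b, (0, 0))

# The same sum written as a conventional matrix product.
same_as_dot = np.array_equal(r00, np.dot(a.T, b))

print(r00)
print(same_as_dot)
```

Summing over the first axis of both arrays is what np.dot does once a has been transposed, which is why the two results agree.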

Passing (0,1,2) as axes is plain wrong. In tuple form, the axes parameter should hold 2 numbers, or 2 tuples, corresponding to the 2 arguments.

The basic processing in tensordot is to transpose and reshape the inputs so it can then pass the results to np.dot for a conventional (last of a, second to the last of b) matrix product.

axes as scalar

If my reading of the tensordot code is right, the axes parameter is converted into two lists with:

def foo(axes):
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    return axes_a, axes_b

For the scalar values 0, 1 and 2, the results are:

In [281]: foo(0)
Out[281]: ([], [])
In [282]: foo(1)
Out[282]: ([-1], [0])
In [283]: foo(2)
Out[283]: ([-2, -1], [0, 1])

axes=1 is the same as specifying in a tuple:

In [284]: foo((-1,0))
Out[284]: ([-1], [0])

And for 2:

In [285]: foo(((-2,-1),(0,1)))
Out[285]: ([-2, -1], [0, 1])

With my latest example, axes=2 is the same as specifying a dot over all axes of the 2 arrays:

In [287]: np.tensordot(a,b,axes=2)
Out[287]: array(18)
In [288]: np.tensordot(a,b,axes=((0,1),(0,1)))
Out[288]: array(18)

This is the same as doing dot on the flattened, 1d, views of the arrays:

In [289]: np.dot(a.ravel(), b.ravel())
Out[289]: 18

I already demonstrated the conventional dot product for these arrays, the axes=1 case.

axes=0 is the same as axes=((),()), no summation axes for the 2 arrays:

In [292]: foo(((),()))
Out[292]: ([], [])

np.tensordot(a,b,((),())) is the same as np.tensordot(a,b,axes=0)

It's the -2 in the foo(2) translation that's giving you problems when the input arrays are 1d: a 1d array has no axis -2. axes=1 is the 'contraction' for 1d arrays. In other words, don't take the word descriptions in the documentation too literally. They just attempt to describe the action of the code; they aren't a formal specification.
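
As a rough check of the 1d case from the question, the sketch below runs all three scalar values. The exception types caught are an assumption on my part; the exact type may differ between NumPy versions.

```python
import numpy as np

a = np.array([1, 1])
b = np.array([2, 2])

outer = np.tensordot(a, b, axes=0)  # no axes summed, shape (2, 2)
inner = np.tensordot(a, b, axes=1)  # the single axis summed, shape ()

# axes=2 asks for the last 2 axes of a and the first 2 of b,
# which 1d arrays do not have.
try:
    np.tensordot(a, b, axes=2)
    raised = False
except (IndexError, ValueError):  # assumed; type may vary by version
    raised = True

print(outer.shape, inner, raised)
```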

einsum equivalents

I think the axes specifications for einsum are clearer and more powerful. Here are the equivalents for 0, 1 and 2:

In [295]: np.einsum('ij,kl',a,b)
Out[295]: 
array([[[[ 2,  3],
         [ 2,  1]],

        [[ 4,  6],
         [ 4,  2]]],


       [[[ 6,  9],
         [ 6,  3]],

        [[ 8, 12],
         [ 8,  4]]]])
In [296]: np.einsum('ij,jk',a,b)
Out[296]: 
array([[ 6,  5],
       [14, 13]])
In [297]: np.einsum('ij,ij',a,b)
Out[297]: 18
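
A small sketch that compares each scalar axes value against its einsum counterpart. The subscripts are written with an explicit output part, which should match the implicit forms above; the dictionary names are hypothetical.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3], [2, 1]])

# scalar axes value -> einsum subscripts with explicit output
pairs = {
    0: 'ij,kl->ijkl',  # nothing summed
    1: 'ij,jk->ik',    # sum over the shared j
    2: 'ij,ij->',      # sum over both i and j
}

matches = {
    n: np.array_equal(np.tensordot(a, b, axes=n), np.einsum(sub, a, b))
    for n, sub in pairs.items()
}
print(matches)
```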

The axes=0 case is equivalent to:

np.dot(a[:,:,None],b[:,None,:])

It adds a new last axis and new 2nd to last axis, and does a conventional dot product summing over those. But we usually do this sort of 'outer' multiplication with broadcasting:

a[:,:,None,None]*b[None,None,:,:]
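
To see that these forms agree, here is a minimal check with the same arrays. The variable names are only for illustration.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3], [2, 1]])

t0 = np.tensordot(a, b, axes=0)

# dot sums over the new size-1 axes: last of a, second to last of b
via_dot = np.dot(a[:, :, None], b[:, None, :])

# plain broadcasting: out[i, j, k, l] = a[i, j] * b[k, l]
via_broadcast = a[:, :, None, None] * b[None, None, :, :]

print(t0.shape)
print(np.array_equal(t0, via_dot), np.array_equal(t0, via_broadcast))
```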

While the use of 0,1,2 for axes is interesting, it really doesn't add new calculation power. The tuple form of axes is more powerful and useful.

code summary (big steps)

1 - translate axes into axes_a and axes_b as excerpted in the above foo function

2 - make a and b into arrays, and get the shape and ndim

3 - check for matching size on axes that will be summed (contracted)

4 - construct a newshape_a and newaxes_a; same for b (complex step)

5 - at = a.transpose(newaxes_a).reshape(newshape_a); same for b

6 - res = dot(at, bt)

7 - reshape the res to desired return shape

5 and 6 are the calculation core. 4 is conceptually the most complex step. For all axes values the calculation is the same, a dot product, but the setup varies.
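
The steps above can be sketched roughly as follows. This is a simplified illustration, not the numpy source: tensordot_sketch is a hypothetical name, and it assumes axes_a and axes_b are already lists of non-negative ints (so step 1 is skipped).

```python
import numpy as np

def tensordot_sketch(a, b, axes_a, axes_b):
    # step 2: make arrays
    a, b = np.asarray(a), np.asarray(b)
    # step 3: sizes on the summed axes must match
    if len(axes_a) != len(axes_b):
        raise ValueError("shape-mismatch for sum")
    for i, j in zip(axes_a, axes_b):
        if a.shape[i] != b.shape[j]:
            raise ValueError("shape-mismatch for sum")
    # step 4: free axes of a go first, summed axes last; reverse for b
    free_a = [k for k in range(a.ndim) if k not in axes_a]
    free_b = [k for k in range(b.ndim) if k not in axes_b]
    n_sum = int(np.prod([a.shape[k] for k in axes_a]))
    rows = int(np.prod([a.shape[k] for k in free_a]))
    cols = int(np.prod([b.shape[k] for k in free_b]))
    # step 5: transpose and flatten each input to 2d
    at = a.transpose(free_a + list(axes_a)).reshape(rows, n_sum)
    bt = b.transpose(list(axes_b) + free_b).reshape(n_sum, cols)
    # step 6: one conventional matrix product
    res = np.dot(at, bt)
    # step 7: restore the free axes
    out_shape = [a.shape[k] for k in free_a] + [b.shape[k] for k in free_b]
    return res.reshape(out_shape)

x = np.arange(24).reshape(4, 2, 3)
y = np.arange(24).reshape(2, 3, 4)
print(np.array_equal(tensordot_sketch(x, y, [1, 2], [0, 1]),
                     np.tensordot(x, y, axes=2)))
```

Whatever the axes, the only arithmetic is the single np.dot call on 2d arrays; everything else is bookkeeping on shapes.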

beyond 0,1,2

While the documentation only mentions 0,1,2 for scalar axes, the code isn't restricted to those values:

In [331]: foo(3)
Out[331]: ([-3, -2, -1], [0, 1, 2])

If the inputs are 3d, axes=3 should work:

In [330]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3)
Out[330]: array(8.)

or more generally:

In [325]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=0).shape
Out[325]: (2, 2, 2, 2, 2, 2)
In [326]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=1).shape
Out[326]: (2, 2, 2, 2)
In [327]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=2).shape
Out[327]: (2, 2)
In [328]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3).shape
Out[328]: ()

and if the inputs are 0d, axes=0 works (axes = 1 does not):

In [335]: np.tensordot(2,3, axes=0)
Out[335]: array(6)

Can you explain this?

In [363]: np.tensordot(np.ones((4,2,3)),np.ones((2,3,4)),axes=2).shape
Out[363]: (4, 4)
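
One way to read that result: axes=2 pairs the last two axes of the first array (sizes 2,3) with the first two axes of the second (sizes 2,3), leaving the two size-4 axes. A minimal sketch, with illustrative variable names:

```python
import numpy as np

x = np.ones((4, 2, 3))
y = np.ones((2, 3, 4))

scalar_form = np.tensordot(x, y, axes=2)
# The tuple form that the scalar 2 is translated into for 3d inputs.
tuple_form = np.tensordot(x, y, axes=((1, 2), (0, 1)))
# The same contraction in einsum notation.
einsum_form = np.einsum('ijk,jkl->il', x, y)

print(scalar_form.shape)
print(np.allclose(scalar_form, tuple_form), np.allclose(scalar_form, einsum_form))
```

Each output element sums 2*3 = 6 products of ones, so every entry is 6.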

I've played around with other scalar axes values for 3d arrays. While it is possible to come up with pairs of shapes that work, the more explicit tuple axes values are easier to work with. The 0,1,2 options are shortcuts that only work for special cases. The tuple approach is much easier to use, though I still prefer the einsum notation.

hpaulj answered Dec 20 '22 10:12



Example 1-0: np.tensordot([1, 1], [2, 2], axes=0)

In this case, a and b both have a single axis and have shape (2,).

The axes=0 argument can be translated to ((the last 0 axes of a), (the first 0 axes of b)), or in this case ((), ()). These are the axes that will be contracted.

All the other axes will not be contracted. Since each of a and b have a 0-th axis and no others, these are the axes ((0,), (0,)).

The tensordot operation is then as follows (roughly):

[
    [x*y for y in b]  # all the non-contraction axes in b
    for x in a        # all the non-contraction axes in a
]
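
A runnable version of that comprehension, compared against np.tensordot. The values [1, 2] and [3, 4] are my own choice (not from the question) so that each entry of the result is distinguishable.

```python
import numpy as np

a = [1, 2]  # assumed values, chosen to be more diagnostic
b = [3, 4]

outer = [
    [x * y for y in b]  # all the non-contraction axes in b
    for x in a          # all the non-contraction axes in a
]

print(outer)
print(np.tensordot(a, b, axes=0).tolist() == outer)
```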

Note that since there are 2 total axes available between a and b and since we're contracting 0 of them, the result has 2 axes. The shape is (2,2) since those are the sizes of the respective non-contracted axes in a and b (in order).

Example 1-1: np.tensordot([1, 1], [2, 2], axes=1)

The axes=1 argument can be translated to ((the last 1 axes of a), (the first 1 axes of b)), or in this case ((0,), (0,)). These are the axes that will be contracted.

All other axes will not be contracted. Since we are already contracting every axis, the remaining axes are ((), ()).

The tensordot operation is then as follows:

sum(  # summing over contraction axis
    [x*y for x,y in zip(a, b)]  # contracted axes must line up
)
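
The same sum as a runnable check against np.tensordot, again with assumed values [1, 2] and [3, 4] rather than the ones from the question.

```python
import numpy as np

a = [1, 2]  # assumed values, chosen to be more diagnostic
b = [3, 4]

inner = sum(x * y for x, y in zip(a, b))  # 1*3 + 2*4
t1 = np.tensordot(a, b, axes=1)

print(inner, t1, t1.shape)
```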

Note that since we're contracting all axes, the result is a scalar (or a 0-dimensional tensor). In numpy, you just get a tensor with shape () representing 0 axes rather than an actual scalar.

Example 1-2: np.tensordot([1, 1], [2, 2], axes=2)

This doesn't work because neither a nor b has two separate axes to contract over.

Example 2-1: np.tensordot([[1,1],[1,1]], [[2,2],[2,2]], axes=1)

I'm skipping a couple of your examples, since I don't think they are complicated enough to add more clarity than the first few.

In this case, a and b both have two axes available (allowing this problem to be a bit more interesting), and they both have shape (2,2).

The axes=1 argument still represents the last 1 axes of a and the first 1 axes of b, leaving us with ((1,), (0,)). These are the axes that will be contracted over.

The remaining axes are not contracted and contribute to the shape of the final solution. These are ((0,), (1,)).

We can then construct the tensordot operation. For the sake of argument, pretend a and b are numpy arrays so that we can use array properties and make the problem cleaner (e.g. b=np.array([[2,2],[2,2]])).

[
    [
        sum(  # summing the contracted indices
            [x*y for x,y in zip(v,w)]  # axis 1 of a and axis 0 of b must line up for the summation
        )
        for w in b.T  # iterating over axis 1 of b (i.e. the columns)
    ]
    for v in a  # iterating over axis 0 of a (i.e. the rows)
]

The result has shape (a.shape[0], b.shape[1]) since these are the non-contracted axes.
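
A runnable sketch of the nested comprehension above. It uses non-uniform values for a and b (an assumption, not the all-ones and all-twos arrays from the question) so that a mistake in the row or column order would show up.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])  # assumed diagnostic values
b = np.array([[2, 3], [2, 1]])

result = [
    [
        int(sum(x * y for x, y in zip(v, w)))  # sum over the contracted index
        for w in b.T  # columns of b (its non-contracted axis 1)
    ]
    for v in a  # rows of a (its non-contracted axis 0)
]

print(result)
print(np.tensordot(a, b, axes=1).tolist() == result)
```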

Hans Musgrave answered Dec 20 '22 09:12
