 

numpy einsum to get axes permutation

What I understood from the documentation of 'np.einsum' is that a permutation string would give a permutation of the axes of an array. This is confirmed by the following experiment:

>>> import numpy as np
>>> M = np.arange(24).reshape(2,3,4)
>>> M.shape
(2, 3, 4)
>>> np.einsum('ijk', M).shape
(2, 3, 4)
>>> np.einsum('ikj', M).shape
(2, 4, 3)
>>> np.einsum('jik', M).shape
(3, 2, 4)

But this I cannot understand:

>>> np.einsum('kij', M).shape
(3, 4, 2)

I would expect (4, 2, 3) instead... What's wrong with my understanding?

asked Jan 30 '15 at 09:01 by Emanuele Paolini


1 Answer

When the output signature is not specified (i.e. there's no '->' in the subscripts string), einsum creates it by taking every label that appears exactly once in the subscripts and arranging those labels in alphabetical order.
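As a rough illustration of that rule, here is a sketch with a hypothetical helper, `implicit_output`, that mimics how the output labels are chosen in implicit mode. It assumes plain letter labels only (no '...' ellipsis handling), so it is not a full reimplementation of einsum's parser:

```python
import numpy as np
from collections import Counter


def implicit_output(subscripts):
    """Hypothetical helper: build the implicit output labels the way the
    answer describes, i.e. keep each label that appears exactly once across
    all operands and sort those labels alphabetically. Ellipses are ignored."""
    counts = Counter(c for c in subscripts if c.isalpha())
    return "".join(sorted(c for c, n in counts.items() if n == 1))


print(implicit_output("kij"))    # ijk
print(implicit_output("ij,jk"))  # ik  ('j' appears twice, so it is summed over)

# Spelling the inferred output out explicitly gives the same result as implicit mode
M = np.arange(24).reshape(2, 3, 4)
print(np.array_equal(np.einsum("kij", M),
                     np.einsum("kij->" + implicit_output("kij"), M)))  # True
```

The second call shows why "exactly once" matters: a repeated label is summed over and dropped from the output, which is how 'ij,jk' becomes an ordinary matrix product.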

This means that

np.einsum('kij', M)

is actually equivalent to

np.einsum('kij->ijk', M)

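So writing 'kij' labels the axes of the input array, not the output array: axis 0 of M (length 2) is called k, axis 1 (length 3) is called i and axis 2 (length 4) is called j. Since the output is arranged as 'ijk', its shape is (3, 4, 2), which is exactly the permutation of the axes that you observed.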
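For a single operand with distinct labels, the implicit result is just a transpose. A minimal sketch, using a hypothetical helper `implicit_permutation` (not part of NumPy), that computes which input axis ends up in each output position and checks it against `np.einsum`:

```python
import numpy as np

M = np.arange(24).reshape(2, 3, 4)


def implicit_permutation(labels):
    """Hypothetical helper for one operand with distinct labels: output axis n
    is the input axis carrying the n-th label in alphabetical order."""
    return tuple(labels.index(c) for c in sorted(labels))


perm = implicit_permutation("kij")
print(perm)                                                    # (1, 2, 0)
print(np.einsum("kij", M).shape)                               # (3, 4, 2)
print(np.array_equal(np.einsum("kij", M), M.transpose(perm)))  # True
```

Reading the tuple: output axis 0 is input axis 1 (label i), output axis 1 is input axis 2 (label j), and output axis 2 is input axis 0 (label k).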

This point isn't made explicit in the documentation, but it is spelled out in a comment in the C source code for einsum:

/*
 * If there is no output signature, create one using each label
 * that appeared once, in alphabetical order
 */

To ensure the axes of M are permuted in the intended order, the clearest approach is to give einsum the labels for both the input and the output arrays:

>>> np.einsum('ijk->kij', M).shape
(4, 2, 3)
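A short sketch confirming that the explicit form matches `np.transpose` with axes (2, 0, 1), and that implicit mode can reach the same result only by choosing input labels whose alphabetical order produces the wanted permutation ('jki' here, picked for illustration):

```python
import numpy as np

M = np.arange(24).reshape(2, 3, 4)

explicit = np.einsum("ijk->kij", M)
print(explicit.shape)  # (4, 2, 3)

# Same reordering with transpose: output axes are input axes 2, 0, 1
print(np.array_equal(explicit, np.transpose(M, (2, 0, 1))))  # True

# Implicit mode: labelling the input 'jki' makes the alphabetical output 'ijk'
# take input axis 2 first, then axis 0, then axis 1
print(np.array_equal(explicit, np.einsum("jki", M)))  # True
```

The explicit '->' form states the intended output order directly, so it is less error-prone than reverse-engineering an implicit label string.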
answered Oct 25 '22 at 22:10 by Alex Riley