Sort a numpy matrix based on its diagonal

I have a matrix that should have ones on the diagonal but the columns are mixed up.

[Image: messed-up matrix]

But I don't know how to efficiently interchange rows to get ones on the diagonal without the obvious for loop. I'm not even sure what key I would pass to sort on.

Any suggestions?

asked Aug 22 '12 02:08 by mac389


2 Answers

You can use numpy's argmax along each row to find the target column order, then reorder your matrix by using the argmax results as column indices:

>>> import numpy
>>> z = numpy.array([[ 0.1 ,  0.1 ,  1.  ],
...                  [ 1.  ,  0.1 ,  0.09],
...                  [ 0.1 ,  1.  ,  0.2 ]])
>>> numpy.argmax(z, axis=1)  # goal column indices
array([2, 0, 1])
>>> z[:, numpy.argmax(z, axis=1)]
array([[1.  , 0.1 , 0.1 ],
       [0.09, 1.  , 0.1 ],
       [0.2 , 0.1 , 1.  ]])
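The argmax trick assumes that each row's largest entry is its 1 and that no two rows have their 1 in the same column; if two rows pick the same column, fancy indexing silently duplicates that column and drops another. A minimal sketch that adds that check, using a hypothetical helper name reorder_columns_to_diagonal (not a NumPy function) and np.diag to confirm the result:

```python
import numpy as np


def reorder_columns_to_diagonal(z):
    """Hypothetical helper: permute the columns of a square matrix so that
    each row's largest entry lands on the main diagonal."""
    z = np.asarray(z)
    if z.ndim != 2 or z.shape[0] != z.shape[1]:
        raise ValueError("expected a square 2-D matrix")
    # For each row i, the column index holding that row's maximum (its "1").
    cols = np.argmax(z, axis=1)
    # Only a true permutation of 0..n-1 is a valid reordering; otherwise
    # z[:, cols] would duplicate some columns and drop others.
    if not np.array_equal(np.sort(cols), np.arange(z.shape[1])):
        raise ValueError("row maxima do not form a permutation of the columns")
    return z[:, cols]


z = np.array([[0.1, 0.1, 1.0],
              [1.0, 0.1, 0.09],
              [0.1, 1.0, 0.2]])
fixed = reorder_columns_to_diagonal(z)
print(np.diag(fixed))  # [1. 1. 1.]
```

Raising an error is a design choice here: returning a matrix with a duplicated column would hide the problem rather than report it.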
answered Oct 15 '22 13:10 by John Lyon


>>> import numpy as np
>>> a = np.array([[ 1. ,  0.5,  0.5,  0. ],
...               [ 0.5,  0.5,  1. ,  0. ],
...               [ 0. ,  1. ,  0. ,  0.5],
...               [ 0. ,  0.5,  0.5,  1. ]])
>>> # Python 3 has no cmp= argument; sort on the column index of each row's 1
>>> np.array(sorted(a, key=lambda row: list(row).index(1)))
array([[1. , 0.5, 0.5, 0. ],
       [0. , 1. , 0. , 0.5],
       [0.5, 0.5, 1. , 0. ],
       [0. , 0.5, 0.5, 1. ]])

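It actually sorts rows, not columns. The matrix you get is generally different from a column reordering, but it also ends up with ones on the diagonal. It works by sorting the rows by the index of the column each row's 1 is in (note that list(row).index(1) raises ValueError if a row has no entry exactly equal to 1).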
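The same row ordering can be computed without a Python-level sort. A minimal sketch, with one swapped-in assumption: it uses np.argmax instead of an exact index(1) lookup, so it assumes each row's 1 is strictly its largest entry. Sorting the rows by that column index is then a single np.argsort:

```python
import numpy as np

a = np.array([[1.0, 0.5, 0.5, 0.0],
              [0.5, 0.5, 1.0, 0.0],
              [0.0, 1.0, 0.0, 0.5],
              [0.0, 0.5, 0.5, 1.0]])

# Column index of each row's largest entry (assumed to be its 1).
one_col = np.argmax(a, axis=1)  # array([0, 2, 1, 3])

# Order the rows by that column index, the vectorized counterpart of
# sorted(a, key=lambda row: list(row).index(1)).
row_order = np.argsort(one_col, kind="stable")
b = a[row_order]
print(b)
print(np.diag(b))  # [1. 1. 1. 1.]
```

When one_col is a permutation of 0..n-1, argsort simply inverts it, so row i ends up at position one_col[i] and its 1 lands on the diagonal.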

answered Oct 15 '22 14:10 by Snowball