
Construct (N+1)-dimensional diagonal matrix from values in N-dimensional array

I have an N-dimensional array. I want to expand it to an (N+1)-dimensional array by putting the values of the final dimension in the diagonal.

For example, using explicit looping:

In [197]: M = numpy.arange(5*3).reshape(5, 3)

In [198]: numpy.dstack([numpy.diag(M[i, :]) for i in range(M.shape[0])]).T
Out[198]: 
array([[[ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  0,  2]],

       [[ 3,  0,  0],
        [ 0,  4,  0],
        [ 0,  0,  5]],

       [[ 6,  0,  0],
        [ 0,  7,  0],
        [ 0,  0,  8]],

       [[ 9,  0,  0],
        [ 0, 10,  0],
        [ 0,  0, 11]],

       [[12,  0,  0],
        [ 0, 13,  0],
        [ 0,  0, 14]]])

which is a 5×3×3 array.

My actual arrays are large, so I would like to avoid explicit looping (hiding the loop in map instead of a list comprehension brings no performance gain; it's still a loop). numpy.diag constructs a regular 2-D diagonal matrix from a 1-D array, but it does not extend to higher dimensions: when given a 2-D array, it extracts the diagonal instead. numpy.diagflat flattens its input and puts everything on one big diagonal, resulting in a 15×15 array, which has far more zeroes and cannot be reshaped into 5×3×3.
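As a quick illustration of the two behaviours described above, here is a minimal sketch. It assumes the same 5×3 M as in the example; the names extracted and flat_diag are only illustrative.

```python
import numpy

M = numpy.arange(5 * 3).reshape(5, 3)

# Given a 2-D array, numpy.diag extracts the main diagonal
# instead of building a diagonal matrix.
extracted = numpy.diag(M)        # 1-D array of length 3

# numpy.diagflat flattens the input first and places all 15 values
# on a single diagonal, giving one large 2-D array.
flat_diag = numpy.diagflat(M)    # 2-D array of shape (15, 15)

print(extracted.shape, flat_diag.shape)
```

Neither result has the 5×3×3 shape that the loop produces.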

Is there a way to efficiently construct an (N+1)-dimensional diagonal matrix from the values in an N-dimensional array, without calling diag many times?

gerrit asked Feb 05 '18 16:02


1 Answer

Use numpy.diagonal to take a view of the relevant diagonals of a properly shaped (N+1)-dimensional array of zeros. The view is read-only by default, so force it to be writeable with setflags, then write to the view:

expanded = numpy.zeros(M.shape + M.shape[-1:], dtype=M.dtype)

diagonals = numpy.diagonal(expanded, axis1=-2, axis2=-1)
diagonals.setflags(write=True)

diagonals[:] = M

This produces your desired array as expanded.
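For anyone who wants to check this, the sketch below runs the steps above on the 5×3 M from the question and compares the result with the explicit loop. It also tries an alternative that is not part of the answer: assigning through integer-array indexing on the last two axes. The names looped, indexed and idx are illustrative only.

```python
import numpy

M = numpy.arange(5 * 3).reshape(5, 3)

# Approach from the answer: write through a view of the diagonals.
expanded = numpy.zeros(M.shape + M.shape[-1:], dtype=M.dtype)
diagonals = numpy.diagonal(expanded, axis1=-2, axis2=-1)
diagonals.setflags(write=True)
diagonals[:] = M

# Reference result built with the explicit loop from the question.
looped = numpy.stack([numpy.diag(row) for row in M])

# Alternative sketch (an assumption, not from the answer): index the
# diagonal positions of the last two axes directly and assign to them.
idx = numpy.arange(M.shape[-1])
indexed = numpy.zeros(M.shape + M.shape[-1:], dtype=M.dtype)
indexed[..., idx, idx] = M

print(numpy.array_equal(expanded, looped))
print(numpy.array_equal(indexed, looped))
```

Both variants use only the trailing axes, so they should carry over to inputs with more than two dimensions. The indexing variant avoids changing the writeable flag, at the cost of building a small index array.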

user2357112 supports Monica answered Sep 21 '22 05:09