 

How to access lower triangle of N*M*M numpy array

I have a NumPy array with arr.shape == (N, M, M).

I want to access the lower triangle of each (M, M) array. I tried using

arr1 = arr[:,np.tril_indices(M,-1)]
arr1 = arr[:][np.tril_indices(M,-1)]

etc. In the first case the kernel dies, while in the second case I get this error:

   ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-23-1b36c5b12706> in <module>
----> 1 arr1 = arr[:][np.tril_indices(M,-1)]

IndexError: index 6 is out of bounds for axis 0 with size 6

Here, N = 6.

To clarify: I want to find all the elements in the lower triangle of each (M, M) array (N such arrays) and save the result in a new array of shape:

arr1.shape == (N, M*(M-1)//2)

Edit:

While np.tril(arr) works, it returns an array that still has the full shape:

arr1 = np.tril(arr)
arr1.shape

#(N,M,M)

I want the resulting array to have the shape specified above, i.e. I don't want the upper parts of the arrays at all.

Thank you

asked Oct 26 '25 14:10 by Aayush Desai



1 Answer

When working with the tri... set of functions, it can be useful to examine the source code. They are all written in Python and are built on np.tri.

Make a small sample array to illustrate and verify the answer:

In [205]: arr = np.arange(18).reshape(2,3,3)  # arange(1,19) might be better
In [206]: arr
Out[206]: 
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

np.tril sets the upper-triangle values to 0. It works in this case, but its application to 3d arrays is not documented.

In [207]: np.tril(arr) 
Out[207]: 
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])
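A quick way to convince yourself that np.tril treats a 3d array as a stack of 2d matrices is to compare it with a slice-by-slice loop. This is a minimal sketch using the same sample array; the per-slice loop is only a check, not how np.tril is implemented.

```python
import numpy as np

# Same sample stack: 2 matrices, each 3x3
arr = np.arange(18).reshape(2, 3, 3)

# np.tril applied to the whole 3d array at once
whole = np.tril(arr)

# The same operation applied to each 2d slice, then restacked
per_slice = np.stack([np.tril(a) for a in arr])

print(np.array_equal(whole, per_slice))  # True
```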

But if you look at the code, it first constructs a boolean mask from the last 2 dimensions:

In [208]: mask = np.tri(*arr.shape[-2:], dtype=bool)
In [209]: mask
Out[209]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])

and then uses np.where to set the values outside the mask to 0. This works in the 3d case through broadcasting: mask and arr match on the last 2 dimensions, so the (3, 3) mask broadcasts against the (2, 3, 3) array:

In [210]: np.where(mask, arr, 0)
Out[210]: 
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])
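The mask-plus-np.where idea can be written out by hand. The function name tril_via_mask is hypothetical, introduced only for illustration; as a sketch it also passes through the diagonal offset k that np.tri and np.tril accept (k=-1 excludes the main diagonal, as in the question).

```python
import numpy as np

def tril_via_mask(a, k=0):
    """Hypothetical helper: lower triangle via a broadcast boolean mask.

    Builds a (rows, cols) mask from the last two axes and lets np.where
    broadcast it over any leading axes, filling masked-out entries with 0.
    """
    a = np.asarray(a)
    mask = np.tri(*a.shape[-2:], k=k, dtype=bool)
    return np.where(mask, a, 0)

arr = np.arange(18).reshape(2, 3, 3)
lower = tril_via_mask(arr, k=-1)
print(np.array_equal(lower, np.tril(arr, -1)))  # True
```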

np.tril_indices simply returns the indices of the True entries of this mask:

In [217]: np.nonzero(mask)    # aka np.where
Out[217]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [218]: np.tril_indices(3)
Out[218]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
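The same equivalence can be checked for the diagonal offset used in the question as well. A minimal sketch, with M = 4 chosen as an arbitrary test size:

```python
import numpy as np

M = 4  # arbitrary test size
results = {}
for k in (0, -1):
    mask = np.tri(M, k=k, dtype=bool)    # lower-triangle mask with offset k
    from_mask = np.nonzero(mask)         # (rows, cols) of the True entries
    from_tril = np.tril_indices(M, k)    # NumPy's own index pair
    results[k] = all(np.array_equal(a, b) for a, b in zip(from_mask, from_tril))

print(results)  # {0: True, -1: True}
```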

They can't be used directly to index arr:

In [220]: arr[np.tril_indices(3)].shape
Traceback (most recent call last):
  File "<ipython-input-220-e26dc1f514cc>", line 1, in <module>
    arr[np.tril_indices(3)].shape
IndexError: index 2 is out of bounds for axis 0 with size 2

In [221]: arr[:,np.tril_indices(3)].shape
Out[221]: (2, 2, 6, 3)
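The (2, 2, 6, 3) shape comes from NumPy turning the nested tuple into a single (2, K) integer array that indexes axis 1 alone, while the last axis is carried along whole. The sketch below reproduces that, then estimates how large the same expression gets for bigger inputs; this is plausibly why the kernel died on the question's first attempt. The helper fancy_result_nbytes and the size M = 500 are hypothetical (the question does not give M), and 8-byte elements are assumed.

```python
import numpy as np

N, M = 2, 3
arr = np.arange(N * M * M).reshape(N, M, M)

# The two index arrays stacked into one (2, K) integer array
stacked = np.array(np.tril_indices(M))
print(stacked.shape)  # (2, 6)

# Used on axis 1 only: result shape is (N,) + stacked.shape + (M,)
res = arr[:, stacked]
print(res.shape)  # (2, 2, 6, 3)


def fancy_result_nbytes(n, m, k=-1, itemsize=8):
    """Hypothetical helper: bytes that arr[:, np.tril_indices(m, k)] would
    need for an (n, m, m) array, without building the result."""
    n_tri = len(np.tril_indices(m, k)[0])  # K, number of triangle entries
    return n * 2 * n_tri * m * itemsize


# Hypothetical M = 500 with the question's N = 6 and k = -1
print(fancy_result_nbytes(6, 500))  # 5988000000, i.e. about 6 GB
```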

But unpacking the two indexing arrays and applying them to the last two axes works:

In [222]: I,J = np.tril_indices(3)
In [223]: I,J
Out[223]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [224]: arr[:,I,J]
Out[224]: 
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])
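Applied to the question itself, the same unpacking with the offset k=-1 (which excludes the main diagonal, matching the asker's np.tril_indices(M,-1)) gives the requested (N, M*(M-1)//2) shape. A minimal sketch, assuming M = 4 since the question only fixes N = 6:

```python
import numpy as np

# N = 6 is from the question; M = 4 is an assumed size for illustration
N, M = 6, 4
arr = np.arange(N * M * M).reshape(N, M, M)

# k=-1: strict lower triangle, diagonal excluded
I, J = np.tril_indices(M, -1)

# Index the last two axes with the paired arrays; axis 0 is kept whole
arr1 = arr[:, I, J]
print(arr1.shape)  # (6, 6)
print(arr1[0])     # [ 4  8  9 12 13 14]
```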

The boolean mask can also be used directly:

In [226]: arr[:,mask]
Out[226]: 
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])
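For the strict lower triangle as well, the boolean-mask route and the index route agree, including the order of the values, since both walk the triangle in row-major order. A minimal sketch with k=-1 on the same sample array:

```python
import numpy as np

arr = np.arange(18).reshape(2, 3, 3)

# Strict lower triangle (k=-1) as a mask and as an index pair
mask = np.tri(3, k=-1, dtype=bool)
I, J = np.tril_indices(3, -1)

by_mask = arr[:, mask]    # boolean indexing flattens the two masked axes
by_index = arr[:, I, J]   # integer indexing with paired row/col arrays

print(by_mask)                            # [[ 3  6  7] [12 15 16]]
print(np.array_equal(by_mask, by_index))  # True
```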

At bottom, np.tri works by simply doing an outer >= comparison on index ranges:

In [231]: m = np.greater_equal.outer(np.arange(3),np.arange(3))
In [232]: m
Out[232]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
In [234]: np.arange(3)[:,None]>=np.arange(3)
Out[234]: 
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
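The outer-comparison trick extends to non-square shapes and to a diagonal offset k by shifting the column indices. The helper tri_mask is hypothetical, written only to illustrate the idea; the loop checks it against np.tri for a few arbitrary shapes.

```python
import numpy as np

def tri_mask(n, m, k=0):
    """Hypothetical helper mirroring the outer >= idea with offset k.

    Entry (i, j) is True when j <= i + k, i.e. i >= j - k.
    """
    rows = np.arange(n)[:, None]   # column vector of row indices
    cols = np.arange(m) - k        # column indices shifted by k
    return rows >= cols            # broadcasts to shape (n, m)

for n, m, k in [(3, 3, 0), (3, 3, -1), (2, 4, 1), (4, 2, -2)]:
    print(np.array_equal(tri_mask(n, m, k), np.tri(n, m, k, dtype=bool)))
```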
answered Oct 29 '25 06:10 by hpaulj



