 

conditional operation on numpy multidimensional array

I am a naive numpy user and need your help with the following problem: I want to replace the elements of a multidimensional array that are greater than a second array with the values of a third array, e.g.:

import numpy as np

x = np.arange(16).reshape((2, 8))
# x = np.array([[ 0,  1,  2,  3,  4,  5,  6,  7],
#               [ 8,  9, 10, 11, 12, 13, 14, 15]])

And

y = np.array([[2], [13]])
# y = np.array([[ 2], [13]])

Now, find where x is greater than y. If a column of the x > y array contains at least one True, count those columns, create another array (z), and replace those columns of x with z:

x > y 
# = [[False, False, False, True,  True,  True,  True, True],
#    [False, False, False, False, False, False, True, True]]

In this case 5 columns of x (x[:, 3:]) should be replaced, so we create a (5, 2) array, one row of two values per replaced column:

z = np.array([[20,21],[22,23],[24,25],[26,27],[28,29]])

The result I want is:

x == np.array([[ 0,  1,  2, 20, 22, 24, 26, 28],
               [ 8,  9, 10, 21, 23, 25, 27, 29]])
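The column selection and count described above can be computed with a single boolean reduction over axis 0. A minimal sketch, assuming z is built after the count is known; the `np.arange` fill is only a stand-in for however the real z is produced:

```python
import numpy as np

x = np.arange(16).reshape((2, 8))
y = np.array([[2], [13]])

# y has shape (2, 1), so x > y compares each row of x with that row's threshold.
# cols[j] is True if any element in column j exceeds its row's threshold.
cols = (x > y).any(axis=0)   # cols.tolist() -> [False, False, False, True, True, True, True, True]
n = int(cols.sum())          # number of columns to replace -> 5

if n > 0:
    # Stand-in for the real source of z: one row of two values per selected column
    z = np.arange(20, 20 + 2 * n).reshape((n, 2))
```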
Asked Oct 08 '22 05:10 by Aso Agile


1 Answer

A numpy function that does almost exactly what you want is numpy.where:

import numpy as np

x = np.arange(16).reshape((2, 8))
y = np.array([[2], [13]])
z = np.arange(16, 32).reshape((2, 8))
# keep x in columns where no element exceeds y, take z elsewhere
np.where(~(x > y).any(axis=0), x, z)

Result:

array([[ 0,  1,  2, 19, 20, 21, 22, 23],
       [ 8,  9, 10, 27, 28, 29, 30, 31]])

The only difference between this and what you asked for is that z has to be broadcastable to the same shape as x. Unless you absolutely need a z with only one entry per replaced column (that is, as many as there are True values in (x > y).any(axis=0)), I think this is the best approach.
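To illustrate what "broadcastable" allows here: z can be a full (2, 8) array, a scalar, or a (2, 1) column, because the (8,) mask broadcasts across both rows. A minimal sketch; the fill values -1, 100 and 200 are arbitrary choices for illustration, and the arguments are swapped relative to the call above (`np.where(mask, z, x)` instead of `np.where(~mask, x, z)`), which selects the same elements:

```python
import numpy as np

x = np.arange(16).reshape((2, 8))
y = np.array([[2], [13]])

# (8,) boolean mask: True for every column containing at least one x > y
mask = (x > y).any(axis=0)

# Same selection as np.where(~mask, x, z), written without the negation
full_fill = np.where(mask, np.arange(16, 32).reshape((2, 8)), x)

# A scalar broadcasts to every selected position
scalar_fill = np.where(mask, -1, x)

# A (2, 1) column gives each row its own fill value
row_fill = np.where(mask, np.array([[100], [200]]), x)
```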

However, given your comments, it seems that you do need to use a z value as described above. It sounds like the function that produces z guarantees that its shape will match the number of selected columns, so you can probably just do this:

x[:,(x > y).any(axis=0)] = z.T

Tested:

>>> z = np.arange(20, 30).reshape((5, 2))
>>> x[:,(x > y).any(axis=0)] = z.T
>>> x
array([[ 0,  1,  2, 20, 22, 24, 26, 28],
       [ 8,  9, 10, 21, 23, 25, 27, 29]])
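If the shape guarantee might not hold, a small wrapper can check it before assigning and also cover the "only if there is at least one True" part of the question. A minimal sketch; `replace_columns` is a hypothetical helper name, and it works on a copy instead of modifying x in place:

```python
import numpy as np

def replace_columns(x, y, z):
    """Hypothetical helper: return a copy of x in which every column containing
    any x > y is replaced by the matching row of z.

    z must have shape (n_selected_columns, x.shape[0]).
    """
    cols = (x > y).any(axis=0)
    out = x.copy()
    if not cols.any():
        # Nothing exceeds the threshold, so x is returned unchanged
        return out
    n = int(cols.sum())
    if z.shape != (n, x.shape[0]):
        raise ValueError(f"z must have shape {(n, x.shape[0])}, got {z.shape}")
    # out[:, cols] is an (x.shape[0], n) block, which matches z.T
    out[:, cols] = z.T
    return out

x = np.arange(16).reshape((2, 8))
y = np.array([[2], [13]])
z = np.arange(20, 30).reshape((5, 2))
result = replace_columns(x, y, z)
```

Returning a copy keeps the original x available for comparison; if memory matters, the same assignment can be done in place as in the answer.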
Answered Oct 12 '22 12:10 by senderle