 

Conditional operations on numpy arrays

I'm new to NumPy, and I've run into a problem applying conditional statements to NumPy arrays. Let's say I have 3 NumPy arrays that look like this:

a:

[[0, 4, 4, 2],
 [1, 3, 0, 2],
 [3, 2, 4, 4]]

b:

[[6, 9, 8, 6],
 [7, 7, 9, 6],
 [8, 6, 5, 7]]

and c:

[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

I have a condition on a and b, and wherever it is met I would like to use the corresponding value of b to calculate the value of c:

c[(a > 3) & (b > 8)]+=b*2

I get an error saying:

Traceback (most recent call last):
  File "<interactive input>", line 1, in <module>
ValueError: non-broadcastable output operand with shape (1,) doesn't match the broadcast shape (3,4)

Any idea how I can accomplish this?

I would like the output of c to look as follows:

[[0, 18, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]
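To see where the two shapes in the error message come from, here is a minimal reproduction using the arrays above. The name `mask` is introduced here for illustration only. It prints the shape of the masked target and of `b*2`, then catches the ValueError; the exact error wording may differ between NumPy versions.

```python
import numpy as np

a = np.array([[0, 4, 4, 2],
              [1, 3, 0, 2],
              [3, 2, 4, 4]])
b = np.array([[6, 9, 8, 6],
              [7, 7, 9, 6],
              [8, 6, 5, 7]])
c = np.zeros_like(a)

# Boolean mask of positions where both conditions hold.
mask = (a > 3) & (b > 8)

# Boolean indexing returns a 1-D array with one entry per True in the mask.
print(mask.sum())       # number of selected positions
print(c[mask].shape)    # shape of the target of +=
print((b * 2).shape)    # shape of the value being added

raised = False
try:
    # The (3, 4) right-hand side cannot be broadcast into the 1-D target.
    c[mask] += b * 2
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

Only one position satisfies both conditions here, so the target has shape (1,) while `b*2` keeps the full (3, 4) shape, and `c` is left unchanged.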
asked Mar 01 '17 19:03 by bobby12345



2 Answers

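You can use numpy.where. Note that it returns a new array rather than modifying c, so assign the result back to c if you need it there: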

np.where((a > 3) & (b > 8), c + b*2, c)
#array([[ 0, 18,  0,  0],
#       [ 0,  0,  0,  0],
#       [ 0,  0,  0,  0]])
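A self-contained sketch of the `np.where` approach, assuming the arrays from the question. It shows that `np.where` builds a new array and leaves `c` untouched until the result is rebound to `c`; the name `result` is introduced here for illustration.

```python
import numpy as np

a = np.array([[0, 4, 4, 2],
              [1, 3, 0, 2],
              [3, 2, 4, 4]])
b = np.array([[6, 9, 8, 6],
              [7, 7, 9, 6],
              [8, 6, 5, 7]])
c = np.zeros_like(a)

# Take c + b*2 where the condition holds, and c everywhere else.
# np.where allocates a new array; c itself is not modified.
result = np.where((a > 3) & (b > 8), c + b * 2, c)
print(c.any())   # c is still all zeros, so this prints False

# Rebind c to the result if the updated values should live under the name c.
c = result
print(c)
```

Both `c + b*2` and `c` are fully computed before `np.where` selects between them, which is cheap for arrays of this size.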

Or arithmetically:

c + b*2 * ((a > 3) & (b > 8))
#array([[ 0, 18,  0,  0],
#       [ 0,  0,  0,  0],
#       [ 0,  0,  0,  0]])
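The arithmetic version works because a boolean array behaves as 0 and 1 in multiplication. A minimal sketch, assuming the question's arrays (the names `mask`, `arith` and `via_where` are illustrative), that checks it against the `np.where` form:

```python
import numpy as np

a = np.array([[0, 4, 4, 2],
              [1, 3, 0, 2],
              [3, 2, 4, 4]])
b = np.array([[6, 9, 8, 6],
              [7, 7, 9, 6],
              [8, 6, 5, 7]])
c = np.zeros_like(a)

mask = (a > 3) & (b > 8)

# True acts as 1 and False as 0, so b*2*mask keeps 2*b only where the
# mask is True and contributes 0 everywhere else.
print(mask.astype(int))

arith = c + b * 2 * mask
via_where = np.where(mask, c + b * 2, c)
print(np.array_equal(arith, via_where))
```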
answered Oct 04 '22 04:10 by Psidom


The problem is that you mask the left-hand side (the target of +=), but not the right-hand side (the value being added). As a result:

c[(a > 3) & (b > 8)] += b*2
# ^ shape (1,)          ^ shape (3, 4)

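The shapes are not the same: the masked target has one element per True entry in the mask (here just one), while b*2 has shape (3, 4) and cannot be broadcast into it. Given you want element-wise addition (based on your example), apply the same mask to the right-hand side as well: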

c[(a > 3) & (b > 8)]+=b[(a > 3) & (b > 8)]*2

or, more efficiently, compute the mask once and reuse it:

mask = (a > 3) & (b > 8)
c[mask] += b[mask]*2
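As a variant not taken from either answer, NumPy ufuncs such as `np.add` accept `out=` and `where=` arguments, so the same in-place update can be written without boolean indexing. A minimal sketch, assuming the question's arrays; `c_indexed` is an illustrative copy used only to compare against the masked-indexing approach above.

```python
import numpy as np

a = np.array([[0, 4, 4, 2],
              [1, 3, 0, 2],
              [3, 2, 4, 4]])
b = np.array([[6, 9, 8, 6],
              [7, 7, 9, 6],
              [8, 6, 5, 7]])
c = np.zeros_like(a)

mask = (a > 3) & (b > 8)

# Masked indexing on both sides: target and value both have shape (1,).
c_indexed = c.copy()
c_indexed[mask] += b[mask] * 2

# Variant: add b*2 to c only where mask is True, writing straight into c.
# Where mask is False, out keeps its existing values (c was initialized).
np.add(c, b * 2, out=c, where=mask)

print(np.array_equal(c, c_indexed))
print(c)
```

Passing `out=c` matters here: without an initialized output, positions where `where` is False would be left with unspecified values.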
answered Oct 04 '22 05:10 by Willem Van Onsem