
NumPy "where" function cannot avoid evaluating sqrt(negative)

Tags: python, numpy

It seems that np.where evaluates every possible outcome first and only then applies the condition. In my case, this means it computes the square roots of -5, -4, -3, -2 and -1 even though those results are never used.

My code runs and gives the right result; the problem is the warning. I avoided looping over the elements because that would be much slower than np.where.

So my questions are:

  1. Is there any way to make np.where evaluate the condition first?
  2. Can I turn off just this specific warning? How?
  3. Is there a better way to do this altogether?

Here is a short example that reproduces the problem. My real code is much larger, but the issue is essentially the same.

Input:

import numpy as np

c = np.arange(10) - 5
d = np.where(c >= 0, np.sqrt(c), c)

Output:

RuntimeWarning: invalid value encountered in sqrt
d=np.where(c>=0,np.sqrt(c),c)
MH Yip, asked Oct 03 '18 07:10


4 Answers

There is a much better way of doing this. Let's take a look at what your code is doing to see why.

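np.where is an ordinary function that takes three arrays as inputs, and Python evaluates every argument before the function is called. NumPy arrays do not support lazy evaluation.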

d = np.where(c >= 0, np.sqrt(c), c)

This line is therefore equivalent to doing

a = (c >= 0)
b = np.sqrt(c)
d = np.where(a, b, c)

Notice that the inputs are computed immediately, before where ever gets called.
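
To see this eager evaluation in action, here is a small sketch that records warnings separately for the np.sqrt call and for the np.where call. The names sqrt_warned and where_warned are only for illustration, and the sketch assumes NumPy's default floating-point error settings. Under that assumption the warning should come from np.sqrt alone:

```python
import warnings

import numpy as np

c = np.arange(10) - 5

# Step 1: compute the square roots on their own and record any warnings.
with warnings.catch_warnings(record=True) as caught_sqrt:
    warnings.simplefilter("always")
    b = np.sqrt(c)  # runs on every element, negatives included

sqrt_warned = any(issubclass(w.category, RuntimeWarning) for w in caught_sqrt)

# Step 2: np.where only selects between arrays that already exist.
with warnings.catch_warnings(record=True) as caught_where:
    warnings.simplefilter("always")
    d = np.where(c >= 0, b, c)

where_warned = len(caught_where) > 0

print(sqrt_warned, where_warned)
```

The NaN values that np.sqrt produced for the negative inputs never reach d, because np.where picks from c at those positions. The warning is emitted anyway because the square roots were already computed.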

Luckily, you don't need to use where at all. Instead, just use a boolean mask:

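mask = (c >= 0)
# c is an integer array, so request a float output explicitly;
# otherwise the square roots would be truncated to integers.
d = np.empty_like(c, dtype=float)
d[mask] = np.sqrt(c[mask])  # sqrt only sees the non-negative elements
d[~mask] = c[~mask]         # negatives are copied through unchanged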

If you expect a lot of negatives, you can copy all the elements instead of just the negative ones:

d = c.astype(float)  # float copy, so the square roots are not truncated
d[mask] = np.sqrt(c[mask])

An even better solution might be to use masked arrays:

d = np.ma.masked_array(c, c < 0)
d = np.ma.sqrt(d)

To access the whole data array, with the masked portion unaltered, use d.data.
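
As a rough sketch of how the masked result can be consumed afterwards (the variable names here are made up for illustration), compressed() returns only the unmasked values and filled() produces a plain array with a placeholder of your choice in the masked slots:

```python
import numpy as np

c = np.arange(10) - 5
m = np.ma.sqrt(np.ma.masked_array(c, c < 0))

valid_roots = m.compressed()   # 1-D array holding only the unmasked results
with_nan = m.filled(np.nan)    # plain ndarray, NaN where the input was negative
n_masked = int(m.mask.sum())   # number of elements that were masked out

print(valid_roots)
print(with_nan)
print(n_masked)
```

Which accessor fits best depends on whether the downstream code needs the original shape (filled, or d.data as described above) or just the valid values (compressed).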

Mad Physicist, answered Oct 03 '22 22:10


np.sqrt is a ufunc and accepts a where parameter. It can be used as a mask in this case:

In [61]: c = np.arange(10)-5.0
In [62]: d = c.copy()
In [63]: np.sqrt(c, where=c>=0, out=d);
In [64]: d
Out[64]: 
array([-5.        , -4.        , -3.        , -2.        , -1.        ,
        0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ])

In contrast to the np.where case, this does not evaluate the function at the ~where elements.
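
One detail worth spelling out: wherever the where mask is False, out keeps whatever it held before the call, so it should be pre-filled. Here is a minimal sketch that wraps the idea in a helper. The name sqrt_where_nonnegative is hypothetical, not a NumPy function:

```python
import numpy as np


def sqrt_where_nonnegative(x):
    """Hypothetical helper: sqrt where x >= 0, the original value elsewhere."""
    x = np.asarray(x, dtype=float)      # float input so out can hold the roots
    out = x.copy()                      # pre-fill: skipped slots keep x's values
    np.sqrt(x, where=x >= 0, out=out)   # sqrt is applied only where the mask is True
    return out


d = sqrt_where_nonnegative(np.arange(10) - 5)
print(d)
```

If out is omitted, the positions where the mask is False are left uninitialized, so passing a pre-filled out is the safer habit.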

hpaulj, answered Oct 03 '22 23:10


This answers your second question.

Yes, you can turn off warnings with the warnings module. Be aware that the call below silences every warning in the program, not just this one:

import warnings
warnings.filterwarnings("ignore")
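
A narrower option than a global filter is NumPy's np.errstate context manager, which changes floating-point error handling only inside the with block. This is a sketch of that alternative, not part of the original answer:

```python
import numpy as np

c = np.arange(10) - 5

# Ignore only "invalid value" floating-point errors, and only in this block.
with np.errstate(invalid="ignore"):
    d = np.where(c >= 0, np.sqrt(c), c)

print(d)
```

Outside the block NumPy goes back to its previous settings, so invalid operations elsewhere in the program still produce warnings.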
Sociopath, answered Oct 04 '22 00:10


One solution is to not use np.where, and use indexing instead.

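c = np.arange(10) - 5
d = c.astype(float)  # float copy, so the square roots are not truncated to integers
c_positive = c > 0   # zero can be skipped because sqrt(0) == 0
d[c_positive] = np.sqrt(c[c_positive])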
shadowtalker, answered Oct 03 '22 22:10