It seems that the np.where function evaluates all the possible outcomes first, and only evaluates the condition afterwards. This means that, in my case, it will evaluate the square root of -5, -4, -3, -2, -1 even though those values will not be used later on.
My code runs and works, but my problem is the warning. I avoided using a loop to evaluate each element, because it would run much slower than np.where.
So, here, I am asking: is there a way to make np.where evaluate the condition first? Here is a short example corresponding to my real code, which is gigantic but has essentially the same problem.
Input:
import numpy as np
c=np.arange(10)-5
d=np.where(c>=0, np.sqrt(c) ,c )
Output:
RuntimeWarning: invalid value encountered in sqrt
d=np.where(c>=0,np.sqrt(c),c)
There is a much better way of doing this. Let's take a look at what your code is doing to see why.
np.where accepts three arrays as inputs, and arrays do not support lazy evaluation.
d = np.where(c >= 0, np.sqrt(c), c)
This line is therefore equivalent to doing
a = (c >= 0)
b = np.sqrt(c)
d = np.where(a, b, c)
Notice that the inputs are computed immediately, before where ever gets called.
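The eager evaluation described above can be checked with a small sketch. The helper name recording_sqrt and the list seen are hypothetical, introduced only for illustration; the helper records what it was asked to compute before np.where runs.

```python
import numpy as np

seen = []  # hypothetical log of the arrays passed to the helper

def recording_sqrt(x):
    # Hypothetical helper: remember the input, then take the square root.
    # errstate only hides the warning here so the demo output stays clean.
    seen.append(x.copy())
    with np.errstate(invalid="ignore"):
        return np.sqrt(x)

c = np.arange(10) - 5

# Python evaluates all three arguments before np.where is entered,
# so recording_sqrt receives the full array, negatives included.
d = np.where(c >= 0, recording_sqrt(c), c)

print(seen[0])  # the whole of c, not just the non-negative part
```

The helper is called exactly once, with every element of c, which is why the warning appears regardless of the condition.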
Luckily, you don't need to use where at all. Instead, just use a boolean mask:
mask = (c >= 0)
# c is an integer array; use a float output so the roots are not truncated
d = np.empty_like(c, dtype=float)
d[mask] = np.sqrt(c[mask])
d[~mask] = c[~mask]
If you expect a lot of negatives, you can copy all the elements instead of just the negative ones:
# float copy of c, so the roots are not truncated to integers
d = c.astype(float)
d[mask] = np.sqrt(c[mask])
An even better solution might be to use masked arrays:
d = np.ma.masked_array(c, c < 0)
d = np.ma.sqrt(d)
To access the whole data array, with the masked portion unaltered, use d.data.
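A minimal sketch of working with the masked result, assuming the same c as in the question. The variable names masked and result are illustrative; filled and compressed are standard methods of masked arrays.

```python
import numpy as np

c = np.arange(10) - 5

# Mask the negative entries so np.ma.sqrt skips them.
masked = np.ma.masked_array(c, c < 0)
result = np.ma.sqrt(masked)

print(result.mask)            # True where the input was negative
print(result.compressed())    # only the valid (unmasked) square roots
print(result.filled(np.nan))  # plain ndarray with NaN in the masked slots
```

Which accessor is appropriate depends on what the rest of the code expects: filled gives an array of the original length, while compressed drops the masked entries.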
np.sqrt is a ufunc and accepts a where parameter. It can be used as a mask in this case:
In [61]: c = np.arange(10)-5.0
In [62]: d = c.copy()
In [63]: np.sqrt(c, where=c>=0, out=d);
In [64]: d
Out[64]:
array([-5. , -4. , -3. , -2. , -1. ,
0. , 1. , 1.41421356, 1.73205081, 2. ])
In contrast to the np.where case, this does not evaluate the function at the elements where the mask is False.
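One way to confirm that no warning is raised is to record warnings around the call. This is only a sketch, not part of the original answer; the name caught is illustrative. Note that positions where the mask is False keep whatever out already held, which is why out starts as a copy of c.

```python
import warnings
import numpy as np

c = np.arange(10) - 5.0
d = c.copy()  # masked-out positions keep these initial values

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeats
    # sqrt is computed only where c >= 0; results are written into d
    np.sqrt(c, where=c >= 0, out=d)

print(len(caught))  # number of warnings raised by the call
print(d)
```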
This answers your second question.
Yes, you can turn off the warnings. Use the warnings module.
import warnings
warnings.filterwarnings("ignore")
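Be aware that filterwarnings("ignore") silences every warning in the program, not just this one. A narrower option, sketched here as an alternative rather than as part of the answer above, is NumPy's np.errstate context manager, which suppresses only floating-point "invalid value" warnings and only inside the with block.

```python
import numpy as np

c = np.arange(10) - 5

# Suppress only the "invalid value" floating-point warning, and only here.
with np.errstate(invalid="ignore"):
    d = np.where(c >= 0, np.sqrt(c), c)

print(d)
```

The square roots of the negative entries are still computed and then discarded, so this hides the warning without avoiding the wasted work.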
One solution is to not use np.where, and use indexing instead.
c = np.arange(10)-5
# float copy, so the square roots are not truncated to integers
d = c.astype(float)
c_positive = c > 0
d[c_positive] = np.sqrt(c[c_positive])