For writing “piecewise functions” in Python, I'd normally use if (in either the control-flow or ternary-operator form).
def spam(x):
    return x + 1 if x >= 0 else 1 / (1 - x)
Now, with NumPy, the mantra is to avoid working on single values in favour of vectorisation, for performance. So I reckon something like this would be preferred. (As Leon remarks, the following is wrong.)
def eggs(x):
    y = np.zeros_like(x)
    positive = x>=0
    y[positive] = x+1
    y[np.logical_not(positive)] = 1/(1-x)
    return y
(Correct me if I've missed something here, because frankly I find this very ugly.)
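To make the problem with this version concrete, here is a small sketch that runs it on a mixed-sign input. The name eggs_broken and the sample values are just for illustration. The right-hand sides are computed for the whole array, while the boolean mask selects only some positions, so the assignment cannot line up.

```python
import numpy as np

def eggs_broken(x):
    # Same logic as the question's eggs, kept wrong on purpose.
    y = np.zeros_like(x)
    positive = x >= 0
    y[positive] = x + 1  # x + 1 still has len(x) elements
    y[np.logical_not(positive)] = 1 / (1 - x)
    return y

x = np.array([-1.0, 2.0])
try:
    eggs_broken(x)
    failed = False
except ValueError:
    # The mask selects one slot, but two values were offered.
    failed = True

print(failed)
```

A second, quieter issue is that np.zeros_like(x) inherits the dtype of x, so an integer input would also lose the fractional results.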
Now, of course eggs will only work if x is actually a NumPy array, because otherwise x>=0 simply yields a single boolean, which can't be used for indexing (or at least doesn't do the right thing).
Is there a good way to write code that looks more like spam but works idiomatically on NumPy arrays, or should I just use vectorize(spam)?
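For reference, the vectorize route mentioned in the question might look like the sketch below. The otypes argument is my addition, so that the output dtype does not depend on whichever element happens to be evaluated first. Keep in mind that np.vectorize is documented as a convenience wrapper rather than a performance tool: it still calls spam once per element.

```python
import numpy as np

def spam(x):
    return x + 1 if x >= 0 else 1 / (1 - x)

# otypes fixes the result dtype to float instead of inferring it
# from the first computed value.
spam_vec = np.vectorize(spam, otypes=[float])

result = spam_vec(np.arange(-3, 3))
print(result)
```

Because only the chosen branch runs for each element, this version never evaluates 1/(1-x) at x = 1.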
Use np.where. You'll get an array as the output even for plain number input, though.
def eggs(x):
    y = np.asarray(x)
    return np.where(y>=0, y+1, 1/(1-y))
This works for both arrays and plain numbers:
>>> eggs(5)
array(6.0)
>>> eggs(-3)
array(0.25)
>>> eggs(np.arange(-3, 3))
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:2: RuntimeWarning: divide by zero encountered in true_divide
array([ 0.25 , 0.33333333, 0.5 , 1. , 2. , 3. ])
>>> eggs(1)
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:3: RuntimeWarning: divide by zero encountered in long_scalars
# -*- coding: utf-8 -*-
array(2.0)
As ayhan remarks, this raises a warning, since np.where receives both branches fully computed, so 1/(1-x) gets evaluated for the whole range, including x = 1. But a warning is just that: a warning. If you know what you're doing, you can ignore the warning. In this case, you're only choosing 1/(1-x) from indices where x is negative, where it can never be inf, so you're safe.
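If you would rather not see the warning at all, one option is to silence it locally with np.errstate. This is only a sketch of that idea; the warnings filter at the end is there purely to check that nothing leaks out.

```python
import warnings
import numpy as np

def eggs(x):
    y = np.asarray(x)
    # Suppress divide-by-zero warnings only inside this block; the inf
    # produced at y == 1 is discarded by np.where anyway.
    with np.errstate(divide="ignore"):
        return np.where(y >= 0, y + 1, 1 / (1 - y))

with warnings.catch_warnings():
    warnings.simplefilter("error")  # any stray warning becomes an exception
    result = eggs(np.arange(-3, 3))

print(result)
```

The context manager restores the previous floating-point error settings on exit, so other code keeps its usual warnings.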
I would use numpy.asarray (which is a no-op if the argument is already a NumPy array) if I want to handle both numbers and NumPy arrays:
def eggs(x):
    x = np.asfarray(x)
    m = x>=0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x
(Here I used asfarray to enforce a floating-point type, since your function requires floating-point computations. Note that np.asfarray was removed in NumPy 2.0; np.asarray(x, dtype=float) does the same job there.)
This is less efficient than your spam function for single inputs, and arguably uglier. However, it seems to be the easiest choice.
EDIT: If you want to ensure that x is not modified (as pointed out by Leon), you can replace np.asfarray(x) with np.array(x, dtype=np.float64); the array constructor copies by default.
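Putting the EDIT together, a non-mutating version could look roughly like this. The sample array data is made up for the demonstration.

```python
import numpy as np

def eggs(x):
    # np.array copies by default, so the caller's data is left alone.
    x = np.array(x, dtype=np.float64)
    m = x >= 0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])  # only evaluated where x < 0, so no division by zero
    return x

data = np.array([-3.0, -1.0, 0.0, 2.0])
out = eggs(data)

print(out)
print(data)  # unchanged input
```

Since each branch is computed only on its own masked elements, this variant also avoids the divide-by-zero warning.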