 

How to write conditional code that's compatible with both plain Python values and NumPy arrays?

For writing “piecewise functions” in Python, I'd normally use if (in either the control-flow or ternary-operator form).

def spam(x):
    return x+1 if x>=0 else 1/(1-x)
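For context, spam as written works on plain numbers but not on arrays. With an array argument, x>=0 produces a boolean array, and the conditional expression needs a single True/False, so NumPy raises a ValueError. The snippet below is a small illustration only; the example inputs are arbitrary.

```python
import numpy as np

def spam(x):
    return x+1 if x>=0 else 1/(1-x)

print(spam(3))   # plain number: the conditional expression works
print(spam(-1))  # plain number, negative branch

try:
    spam(np.array([-3, 0, 2]))
except ValueError as err:
    # x>=0 is a boolean array here, and `... if ... else ...` needs one bool
    print("ValueError:", err)
```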

Now, with NumPy, the mantra is to avoid working on single values in favour of vectorisation, for performance. So I reckon something like this would be preferred (EDIT: as Leon remarks, the following is wrong):

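import numpy as np

def eggs(x):
    y = np.zeros_like(x)  # inherits x's dtype, so an integer x truncates 1/(1-x)
    positive = x>=0
    y[positive] = x+1  # bug: x+1 has the shape of x, not of x[positive]
    y[np.logical_not(positive)] = 1/(1-x)  # same shape mismatch here
    return y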

(Correct me if I've missed something here, because frankly I find this very ugly.)

Now, of course eggs will only work if x is actually a NumPy array, because otherwise x>=0 simply yields a single boolean, which can't be used for indexing (or at least doesn't do the right thing).

Is there a good way to write code that looks more like spam but works idiomatically on NumPy arrays, or should I just use np.vectorize(spam)?
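For comparison, the np.vectorize route might look like the sketch below. Keep in mind that np.vectorize is essentially a Python-level loop, so it buys convenience rather than speed. Passing otypes=[float] is an addition here: without it, np.vectorize infers the output dtype by calling the function on the first element, so an input starting with 0 (spam(0) == 1, an int) would make the whole result integer and truncate values like 0.25.

```python
import numpy as np

def spam(x):
    return x+1 if x>=0 else 1/(1-x)

# Fix the output dtype up front instead of letting np.vectorize infer it
# from the first element.
vspam = np.vectorize(spam, otypes=[float])

print(vspam(np.arange(-3, 3)))   # element-wise; only one branch runs per element
print(vspam(np.array([0, -3])))  # float output even though spam(0) is an int
print(vspam(5))                  # a plain number comes back as a 0-d array
```

Because spam only evaluates the branch it needs, this version never computes 1/(1-1), so it does not trigger a divide-by-zero warning.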

leftaroundabout asked Sep 08 '16 at 11:09



2 Answers

Use np.where. You'll get an array as the output even for plain number input, though.

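import numpy as np

def eggs(x):
    y = np.asarray(x)
    # Both branches are computed for every element; np.where then picks one per element
    return np.where(y>=0, y+1, 1/(1-y))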

This works for both arrays and plain numbers:

>>> eggs(5)
array(6.0)
>>> eggs(-3)
array(0.25)
>>> eggs(np.arange(-3, 3))
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:2: RuntimeWarning: divide by zero encountered in true_divide
array([ 0.25      ,  0.33333333,  0.5       ,  1.        ,  2.        ,  3.        ])
>>> eggs(1)
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:3: RuntimeWarning: divide by zero encountered in long_scalars
  # -*- coding: utf-8 -*-
array(2.0)

As ayhan remarks, this raises a warning, since 1/(1-x) gets evaluated for every element, including x == 1, where it divides by zero. But a warning is just that: a warning. If you know what you're doing, you can ignore it. In this case, you're only choosing 1/(1-x) at indices where it can never be inf, so you're safe.
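If the warning is a nuisance, one option is to silence it locally with the np.errstate context manager, so floating-point error handling elsewhere in the program is unaffected. This is a sketch, not part of the answer above. The final .item() unwrapping is an extra as well: it hands back a plain Python number when a plain number goes in, since np.where otherwise returns a 0-d array.

```python
import numpy as np

def eggs(x):
    y = np.asarray(x)
    # np.where evaluates 1/(1-y) everywhere, which divides by zero where y == 1.
    # Those values are discarded, so ignore the warning inside this block only.
    with np.errstate(divide='ignore'):
        out = np.where(y >= 0, y + 1, 1 / (1 - y))
    # Optional: return a plain Python number for scalar (0-d) input.
    return out.item() if out.ndim == 0 else out
```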

Praveen answered Sep 28 '22 at 20:09



I would use numpy.asarray (which is a no-op if the argument is already a NumPy array) if I wanted to handle both numbers and NumPy arrays:

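import numpy as np

def eggs(x):
    # np.asfarray was removed in NumPy 2.0; there, use np.asarray(x, dtype=np.float64)
    x = np.asfarray(x)
    m = x>=0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x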

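(Here I used asfarray to enforce a floating-point dtype, since your function requires floating-point computations: with an integer array, the assignment x[~m] = 1 / (1 - x[~m]) would truncate a result such as 0.25 to 0.)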

This is less efficient than your spam function for single inputs, and arguably uglier. However, it seems to be the easiest choice.

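EDIT: If you want to ensure that the caller's x is not modified (as pointed out by Leon), you can replace np.asfarray(x) by np.array(x, dtype=np.float64). np.asfarray returns the input array itself when it is already a float64 array, whereas the array constructor copies by default.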
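The difference is easy to see side by side. The sketch below assumes NumPy 2.x, where np.asfarray no longer exists, so np.asarray(x, dtype=np.float64) stands in for it (for float64 input both return the caller's array unchanged). The names _apply_in_place, eggs_inplace and eggs_copy are hypothetical, chosen for this illustration.

```python
import numpy as np

def _apply_in_place(x):
    # Hypothetical helper: x must be a float array; it is overwritten.
    m = x >= 0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x

def eggs_inplace(x):
    # No copy for float64 input, so the caller's array gets overwritten.
    return _apply_in_place(np.asarray(x, dtype=np.float64))

def eggs_copy(x):
    # np.array copies by default, so the caller's array is left untouched.
    return _apply_in_place(np.array(x, dtype=np.float64))

a = np.array([-3.0, 0.0, 2.0])
b = eggs_copy(a)     # a still holds [-3.0, 0.0, 2.0]
c = eggs_inplace(a)  # c is a, and a now holds [0.25, 1.0, 3.0]
```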

Maxim answered Sep 28 '22 at 20:09
