
A Python function that accepts as an argument either a scalar or a NumPy array

As the title says, suppose I want to write a sign function (let's ignore sign(0) for now). We expect sign(2) = 1 and sign(array([-2,-2,2])) = array([-1,-1,1]). The following function won't work, however, because the if statement can't handle NumPy arrays.

def sign(x):
    if x>0: return 1
    else: return -1

The next function won't work either, since x has no shape attribute if it is just a single number. Even if a trick like y = x*0 + 1 is used, y won't support [] indexing.

from numpy import ones

def sign(x):
    y = ones(x.shape)
    y[x<0] = -1
    return y

Even with the idea from another question (how can I make a numpy function that accepts a numpy array, an iterable, or a scalar?), the next function won't work when x is a single number, because in that case x.shape and y.shape are both () and indexing y is illegal.

from numpy import asarray, ones

def sign(x):
    x = asarray(x)
    y = ones(x.shape)
    y[x<0] = -1
    return y

The only solution seems to be to first decide whether x is an array or a number, but I want to know if there is something better. Writing branchy code would be cumbersome if you have lots of small functions like this.
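For reference, the branchy approach described above might look like the following minimal sketch. The name sign_branchy is hypothetical, and using np.ndim to tell scalars from arrays is one assumed way of doing the check, not the only one.

```python
import numpy as np

def sign_branchy(x):
    """Hypothetical helper: branch on scalar vs. array input."""
    if np.ndim(x) == 0:
        # Scalar path: a plain Python comparison is enough.
        return 1 if x > 0 else -1
    # Array path: boolean-mask assignment on a same-shaped array.
    x = np.asarray(x)
    y = np.ones(x.shape, dtype=int)
    y[x <= 0] = -1
    return y

print(sign_branchy(2))
print(sign_branchy(np.array([-2, -2, 2])))
```

Every small function written this way needs the same two code paths, which is the duplication I would like to avoid.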

Taozi asked Oct 24 '14 06:10



1 Answer

np.vectorize can be used to achieve that, but it would be slow: when the vectorized function is called with an array, all it does is loop through the array elements and apply the scalar function to each one, so it does not leverage NumPy's speed.
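A minimal sketch of the np.vectorize route, assuming a scalar-only function like the first one in the question (the names scalar_sign and sign_vec are hypothetical):

```python
import numpy as np

def scalar_sign(x):
    # Scalar-only logic, as in the question's first attempt.
    return 1 if x > 0 else -1

# np.vectorize wraps the scalar function so it also accepts arrays;
# internally it still calls scalar_sign once per element.
sign_vec = np.vectorize(scalar_sign)

print(int(sign_vec(2)))
print(sign_vec(np.array([-2, -2, 2])).tolist())
```

Note that for a scalar input the wrapped function returns a NumPy value rather than a plain Python int, hence the int() conversion in the first call.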

A method I find useful for vectorizing functions that involve if-else is np.choose:

import numpy as np

def sign_non_zero(x):
    return np.choose(
        x > 0,  # bool values, used as indices to the array
        [
            -1, # index=0=False, i.e. x<=0
            1,  # index=1=True, i.e. x>0
        ])

This works whether x is a scalar or an array, and is faster than looping in Python space.
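As a quick check of that claim, the sketch below repeats the function in compact form and calls it with a plain number and with an array. The exact return type for a scalar input is an assumption that may vary between NumPy versions, so the scalar result is converted with int() before printing.

```python
import numpy as np

def sign_non_zero(x):
    # x > 0 yields False/True, which np.choose uses as index 0/1.
    return np.choose(x > 0, [-1, 1])

scalar_result = sign_non_zero(2)
array_result = sign_non_zero(np.array([-2, -2, 2]))

print(int(scalar_result))
print(array_result.tolist())
```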

The only disadvantage of np.choose is that writing if-else logic this way is not intuitive, and the code is less readable. Whenever I use it, I include comments like the ones above to make it easier for the reader to understand what is going on.
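If readability is the main concern, np.where expresses the same if-else more directly. This is an alternative that is not part of the approach above, shown only as a minimal sketch; sign_where is a hypothetical name.

```python
import numpy as np

def sign_where(x):
    # Where the condition holds pick 1, otherwise pick -1.
    # np.asarray lets lists and plain numbers go through the same path.
    return np.where(np.asarray(x) > 0, 1, -1)

print(int(sign_where(2)))
print(sign_where([-2, -2, 2]).tolist())
```

As with np.choose, a scalar input comes back as a NumPy value, so int(...) is used here where a plain Python int is wanted.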

shx2 answered Sep 20 '22 16:09