 

Automatically use list comprehension/map() recursion if a function is given a list

As a Mathematica user, I like functions that automatically "thread over lists" (as the Mathematica people call it; see http://reference.wolfram.com/mathematica/ref/Listable.html). That means that if a function is given a list instead of a single value, it automatically uses each list entry as an argument and returns a list of the results, e.g.

myfunc([1,2,3,4]) -> [myfunc(1),myfunc(2),myfunc(3),myfunc(4)]

I implemented this principle in Python like this:

def myfunc(x):
    if isinstance(x, list):
        return [myfunc(thisx) for thisx in x]
    # rest of the function
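Filled out with a concrete body, the pattern above becomes runnable. In this minimal sketch the squaring line `return x * x` is a hypothetical stand-in for "rest of the function". Because the list branch calls `myfunc` itself, nested lists are threaded over as well.

```python
def myfunc(x):
    """Square x; if x is a list, apply myfunc to each element recursively."""
    if isinstance(x, list):
        # Recursive call: inner lists are threaded over too
        return [myfunc(thisx) for thisx in x]
    # Hypothetical single-value body standing in for "rest of the function"
    return x * x


print(myfunc(3))               # 9
print(myfunc([1, 2, 3, 4]))    # [1, 4, 9, 16]
print(myfunc([1, [2, 3], 4]))  # [1, [4, 9], 16]
```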

Is this a good way to do it? Can you think of any downsides of this implementation or the strategy overall?

asked Dec 26 '22 at 18:12 by zonksoft



2 Answers

That's a good way to do it. However, you would have to repeat it in every function you write. To avoid that, you could use a decorator like this one:

def threads(fun):
    def wrapper(element_or_list):
        if isinstance(element_or_list, list):
            return [fun(element) for element in element_or_list]
        else:
            return fun(element_or_list)

    return wrapper

@threads
def plusone(e):
    return e + 1

print(plusone(1))
print(plusone([1, 2, 3]))
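One difference from the question's version: this wrapper calls `fun` directly on each element, so a nested list such as `[1, [2, 3]]` would pass `[2, 3]` whole to `plusone` and raise a `TypeError`. Below is a minimal sketch of a recursive variant (the name `threads_recursive` is hypothetical) that recurses through the wrapper and also uses `functools.wraps` so the decorated function keeps its `__name__` and docstring.

```python
import functools


def threads_recursive(fun):
    """Decorator: apply fun element-wise, descending into nested lists."""
    @functools.wraps(fun)  # copy fun's __name__, __doc__, etc. onto wrapper
    def wrapper(element_or_list):
        if isinstance(element_or_list, list):
            # Recurse through wrapper (not fun) so inner lists are threaded too
            return [wrapper(element) for element in element_or_list]
        return fun(element_or_list)
    return wrapper


@threads_recursive
def plusone(e):
    """Add one to e."""
    return e + 1


print(plusone(1))            # 2
print(plusone([1, [2, 3]]))  # [2, [3, 4]]
print(plusone.__name__)      # plusone
```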
answered Dec 29 '22 at 06:12 by Scharron



If this is something you're going to do in a lot of functions, you could use a Python decorator. Here's a simple but useful one.

def threads_over_lists(fn):
    def wrapped(x, *args, **kwargs):
        if isinstance(x, list):
            return [fn(e, *args, **kwargs) for e in x]
        return fn(x, *args, **kwargs)
    return wrapped

With this, just adding the line @threads_over_lists above a function definition makes it behave this way. For example:

@threads_over_lists
def add_1(val):
    return val + 1

print(add_1(10))
print(add_1([10, 15, 20]))

# if there are multiple arguments, threads only over the first argument,
# keeping the others the same

@threads_over_lists
def add_2_numbers(x, y):
    return x + y

print(add_2_numbers(1, 10))
print(add_2_numbers([1, 2, 3], 10))

You should also consider whether you want this to vectorize only over lists, or also over other iterable objects like tuples and generators. This is a useful Stack Overflow question for determining that. Be careful, though: a string is iterable, but you probably won't want your function operating on each character within it.
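As one possible answer to that question, here is a minimal sketch that threads over any non-string iterable by checking against `collections.abc.Iterable` and explicitly excluding `str` and `bytes`. The names `threads_over_iterables` and `shout` are hypothetical. Note the design choice: the result is always a list regardless of the input type, and dicts and sets would also be threaded over (a dict over its keys), which may or may not be what you want.

```python
import functools
from collections.abc import Iterable


def threads_over_iterables(fn):
    """Thread fn over its first argument if that is a non-string iterable."""
    @functools.wraps(fn)
    def wrapped(x, *args, **kwargs):
        # str and bytes are iterable, but treat them as single values
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            return [fn(e, *args, **kwargs) for e in x]
        return fn(x, *args, **kwargs)
    return wrapped


@threads_over_iterables
def shout(s, suffix="!"):
    return s.upper() + suffix


print(shout("hi"))                                 # HI!
print(shout(("a", "b")))                           # ['A!', 'B!']
print(shout((w for w in ["x", "y"]), suffix="?"))  # ['X?', 'Y?']
```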

answered Dec 29 '22 at 08:12 by David Robinson
