
Injecting "global imports" into Python functions



Short, but complete, summary

I want to let users of my function (a class factory) inject or override its global imports when they call it (the rationale is explained at length below). But there are about 10 different names that could be passed in, and handling them adds a lot of very repetitive lines to the code (granted, it also makes the function more complicated to call :P). Right now I'm doing something like the following, heavily simplified. To make it runnable I'm using a dummy class, whereas the actual script uses import pkg1, etc.; I figured this was clearer and shorter than a real class factory.

class Dummy(object): pass

pkg1, pkg2 = Dummy(), Dummy()
pkg1.average = lambda *args : sum(args) / len(args)
pkg2.get_lengths = lambda *args : map(len, args)

def get_average(*args, **kwargs):
    average = kwargs.get("average") or pkg1.average
    get_lengths = kwargs.get("get_lengths") or pkg2.get_lengths
    return average(*get_lengths(*args))

adjusted_length = lambda *args: map(len, args) + [15]
print get_average([1,2], [10, 4, 5, 6]) == 3 # True
print get_average([1,2], [10, 4, 5, 6], get_lengths=adjusted_length) == 7 # True

Related SO questions

This Stack Overflow post, Modifying locals in Python, seemed particularly relevant. Initially I wanted to just override the names by storing them into the locals() dictionary, but (1) it didn't seem to work, and (2) it seems like a bad idea anyway. So I'm wondering whether there's another way to do it.
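For what it's worth, here is a minimal Python 3 sketch (the post itself uses Python 2) of why the locals() route fails in this case. Besides the documented caveat that changes to the dictionary returned by locals() may not affect a function's variables, a name that is never assigned inside the function body is compiled as a global lookup, so adding a key to that dictionary cannot change what the name refers to. The function demo is purely illustrative.

```python
# Python 3 sketch; `demo` is an illustrative name only.
def average(*xs):
    return sum(xs) / len(xs)

def demo(override):
    localvars = locals()
    # This only adds a key to a dict; it does not rebind any name.
    localvars["average"] = override
    # `average` is never assigned in this function, so the compiler
    # emits a global lookup and the module-level function is used.
    return average(2, 4)

result = demo(lambda *xs: 0)
print(result)  # 3.0
```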

This one looked promising (Adding an object to another module's globals in python), but I'm not sure how to access the globals of the current file the same way as another module's. (And this question, python: mutating `globals` to dynamically put things in scope, doesn't really apply, since I'm ultimately using this to define classes.)
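A different technique, not taken from the linked posts, is to leave the module alone and build a copy of the function whose globals dictionary contains the overrides, using types.FunctionType. The sketch below is Python 3, and with_globals is a hypothetical helper name. The copied dictionary is a snapshot of the module's globals at the time of the call, and only the copied function sees the overrides (anything it calls still uses its own globals); for a class factory this would have to be applied to each method.

```python
import types

def average(*xs):
    return sum(xs) / len(xs)

def get_lengths(*seqs):
    return [len(s) for s in seqs]

def get_average(*args):
    return average(*get_lengths(*args))

def with_globals(func, **overrides):
    """Hypothetical helper: copy `func` so its global lookups see `overrides`."""
    # Snapshot the defining module's globals, then layer the overrides on top.
    new_globals = dict(func.__globals__)
    new_globals.update(overrides)
    clone = types.FunctionType(func.__code__, new_globals, func.__name__,
                               func.__defaults__, func.__closure__)
    clone.__kwdefaults__ = func.__kwdefaults__
    return clone

def adjusted_length(*seqs):
    return [len(s) for s in seqs] + [15]

patched = with_globals(get_average, get_lengths=adjusted_length)
print(get_average([1, 2], [10, 4, 5, 6]))  # 3.0
print(patched([1, 2], [10, 4, 5, 6]))      # 7.0
```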

I guess I could wrap everything in an exec statement (as in this post, globals and locals in python exec()), but that's fiddly and makes error checking, linting, etc. much harder.
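For comparison, a minimal Python 3 sketch of the exec() route: the dictionary passed as exec()'s globals becomes the global namespace of the executed code, so the overrides are visible to it. The code lives in a string, which is exactly what makes linting and error checking harder. SOURCE and build_get_average are hypothetical names.

```python
# Hypothetical source string; linters cannot see inside it.
SOURCE = """
def get_average(*args):
    return average(*get_lengths(*args))
"""

def build_get_average(average, get_lengths):
    # The dict passed to exec() serves as the globals of the code it runs,
    # so `average` and `get_lengths` inside SOURCE resolve against it.
    namespace = {"average": average, "get_lengths": get_lengths}
    exec(SOURCE, namespace)
    return namespace["get_average"]

fn = build_get_average(lambda *xs: sum(xs) / len(xs),
                       lambda *seqs: [len(s) for s in seqs])
print(fn([1, 2], [10, 4, 5, 6]))  # 3.0
```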

So here's what I'd like to do. (NOTE: In practice I would use from pkg1 import average and from pkg2 import get_lengths, but I wanted the example to be clearer; you need to copy pkg1, pkg2 and adjusted_length from above to run this.)

average = pkg1.average
get_lengths = pkg2.get_lengths

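def get_average(*args, **kwargs):
    localvars = locals()
    for k in ("get_lengths", "average"):
        if kwargs.get(k) is not None:
            # Only adds a key to the dict returned by locals(); the names
            # below are still looked up as globals, so this has no effect.
            localvars[k] = kwargs[k]
    return average(*get_lengths(*args))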

print get_average([1,2], [10, 4, 5, 6]) == 3 #True
print get_average([1,2], [10, 4, 5, 6], get_lengths=adjusted_length) == 7 # False, is 3
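One way to keep the override loop without touching locals() is to collect the overridable names in an ordinary dictionary and always look them up through it. A minimal Python 3 sketch, where DEFAULTS and resolve are hypothetical names and the two helpers stand in for pkg1.average and pkg2.get_lengths; unknown keyword arguments are rejected so typos don't pass silently.

```python
def average(*xs):
    return sum(xs) / len(xs)

def get_lengths(*seqs):
    return [len(s) for s in seqs]

# Hypothetical registry of overridable dependencies and their defaults.
DEFAULTS = {"average": average, "get_lengths": get_lengths}

def resolve(kwargs):
    """Merge caller overrides (ignoring None values) over DEFAULTS."""
    unknown = set(kwargs) - set(DEFAULTS)
    if unknown:
        raise TypeError("unexpected overrides: %s" % ", ".join(sorted(unknown)))
    deps = dict(DEFAULTS)
    deps.update((k, v) for k, v in kwargs.items() if v is not None)
    return deps

def get_average(*args, **kwargs):
    deps = resolve(kwargs)
    return deps["average"](*deps["get_lengths"](*args))

def adjusted_length(*seqs):
    return [len(s) for s in seqs] + [15]

print(get_average([1, 2], [10, 4, 5, 6]))                              # 3.0
print(get_average([1, 2], [10, 4, 5, 6], get_lengths=adjusted_length))  # 7.0
```

Adding another overridable name then means adding one entry to DEFAULTS rather than another repetitive line in every function.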

Rationale for my specific use-case

Right now, I'm trying to write a dynamically generated class factory (to use as an SQLAlchemy mixin), and I want to let users of my class pass in alternate constructors so they can use SQLAlchemy adapters, etc.

For example, Flask-SQLAlchemy provides the same interface as SQLAlchemy, but through a custom object (db) that wraps all the SQLAlchemy objects to provide extra features.
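For the SQLAlchemy case specifically, one option is to have the factory take the object that provides the constructors as an argument and build the class inside the function, so the class body reads names from that argument rather than from module globals. A minimal Python 3 sketch: make_id_mixin is a hypothetical name, and it assumes only that db has Column and Integer attributes, which both the sqlalchemy module and a Flask-SQLAlchemy db object provide. A stand-in namespace replaces SQLAlchemy so the sketch runs on its own.

```python
from types import SimpleNamespace

def make_id_mixin(db):
    """Hypothetical factory: build a mixin from whatever `db` provides.

    `db` only needs `Column` and `Integer` attributes.
    """
    class IdMixin(object):
        # The class body can read `db` from the enclosing function's scope.
        id = db.Column(db.Integer, primary_key=True)
    return IdMixin

# Stand-in for sqlalchemy so the sketch runs without it installed.
fake_db = SimpleNamespace(
    Column=lambda *args, **kwargs: ("Column", args, kwargs),
    Integer="INTEGER",
)
Mixin = make_id_mixin(fake_db)
print(Mixin.id)  # ('Column', ('INTEGER',), {'primary_key': True})
```

With the real libraries, the same factory would be called as make_id_mixin(sqlalchemy) or make_id_mixin(db).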

Jeff Tratner asked Nov 04 '22 19:11

1 Answer

You could use arguments with default values to pass the functions in. This is effectively what you are already doing, but cleaner. I've used a single lists argument instead of *args because that is easier to deal with when there are other arguments. You'll have to enclose your lists in a tuple to pass them to get_average.

The built-in function sorted works like this (its key and reverse arguments have defaults), so it should be easy for Python programmers to understand.

# pkg1, pkg2 and adjusted_length as defined in the question
def get_average(lists, average=pkg1.average, get_lengths=pkg2.get_lengths):
    return average(*get_lengths(*lists))

print get_average(([1,2], [10, 4, 5, 6]))
print get_average(([1,2], [10, 4, 5, 6]), get_lengths=adjusted_length)
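In Python 3, parameters declared after *args are keyword-only, so the original *args call style can be kept and the tuple wrapping becomes unnecessary. A minimal sketch; default_average and default_get_lengths are stand-ins for pkg1.average and pkg2.get_lengths.

```python
# Python 3 sketch: stand-ins for pkg1.average and pkg2.get_lengths.
def default_average(*xs):
    return sum(xs) / len(xs)

def default_get_lengths(*seqs):
    return [len(s) for s in seqs]

# Parameters after *lists are keyword-only, so callers keep passing
# the lists positionally and can still override the helpers by name.
def get_average(*lists, average=default_average, get_lengths=default_get_lengths):
    return average(*get_lengths(*lists))

def adjusted_length(*seqs):
    return [len(s) for s in seqs] + [15]

print(get_average([1, 2], [10, 4, 5, 6]))                              # 3.0
print(get_average([1, 2], [10, 4, 5, 6], get_lengths=adjusted_length))  # 7.0
```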

If you have many keyword arguments, you could package them in an object:

class GetAverageContext(object):
    def __init__(self, average=pkg1.average, get_lengths=pkg2.get_lengths):
        self.average = average
        self.get_lengths = get_lengths

DefaultGetAverageContext = GetAverageContext()

def get_average(lists, context=DefaultGetAverageContext):
    return context.average(*context.get_lengths(*lists))
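A Python 3 variation on the context object, sketched with a frozen dataclass so callers can override a single dependency with dataclasses.replace() without mutating the shared default. The helper functions are stand-ins for pkg1.average and pkg2.get_lengths, and DEFAULT_CONTEXT is a hypothetical name.

```python
from dataclasses import dataclass, replace
from typing import Callable

def default_average(*xs):
    return sum(xs) / len(xs)

def default_get_lengths(*seqs):
    return [len(s) for s in seqs]

@dataclass(frozen=True)
class GetAverageContext:
    # Each field is one overridable dependency with its default.
    average: Callable = default_average
    get_lengths: Callable = default_get_lengths

DEFAULT_CONTEXT = GetAverageContext()

def get_average(lists, context=DEFAULT_CONTEXT):
    return context.average(*context.get_lengths(*lists))

def adjusted_length(*seqs):
    return [len(s) for s in seqs] + [15]

# replace() copies the frozen default and swaps only the named field.
custom = replace(DEFAULT_CONTEXT, get_lengths=adjusted_length)
print(get_average(([1, 2], [10, 4, 5, 6])))          # 3.0
print(get_average(([1, 2], [10, 4, 5, 6]), custom))  # 7.0
```

Because the default instance is frozen, assigning to its fields raises an error, which makes it safe to share as a default argument value.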
Peter Graham answered Nov 14 '22 21:11