 

Method overloading decorator

I'm trying to write a decorator that provides method overloading functionality to python, similar to the one mentioned in PEP 3124.

The decorator I wrote works great for regular functions, but I can't get it to work for methods in a class.

Here is the decorator:

class Overload(object):
    def __init__(self, default):
        self.default_function = default
        self.type_map = {}
        self.pos = None

    def __call__(self, *args, **kwargs):
        print self
        try:
            if self.pos is None:
                pos = kwargs.get("pos", 0)
            else:
                pos = self.pos
            print args, kwargs
            return self.type_map[type(args[pos])](*args, **kwargs)
        except KeyError:
            return self.default_function(*args, **kwargs)
        except IndexError:
            return self.default_function(*args, **kwargs)

    def overload(self, *d_type):
        def wrapper(f):
            for dt in d_type:
                self.type_map[dt] = f
            return self
        return wrapper

When I attempt to implement it like this:

class MyClass(object):
    def __init__(self):
        self.some_instance_var = 1

    @Overload
    def print_first_item(self, x):
        return x[0], self.some_instance_var

    @print_first_item.overload(str)
    def print_first_item(self, x):
        return x.split()[0], self.some_instance_var

I get a TypeError when I run it:

>>> m = MyClass()
>>> m.print_first_item(1) 
<__main__.Overload object at 0x2> (1,) {} 
Traceback (most recent call last):   
  File "<stdin>", line 1, in <module>   
  File "overload.py", line 17, in __call__
    return self.default_function(*args, **kwargs) 
  TypeError: print_first_item() takes exactly 2 arguments (1 given)
>>>

My question is: How can I access the instance of MyClass (i.e. self) from within the decorated method?

Joel Cornett asked Jul 07 '12 00:07


3 Answers

Essentially, your Overload class needs a __get__ method:

def __get__(self, obj, cls):
    # Called on access of MyClass.print_first_item.
    # We return a wrapper which calls our __call__ with obj prepended as self.
    print "get", self, obj, cls
    if obj is None:
        # a function would do some checks here, but we leave that.
        return self
    else:
        return lambda *a, **k: self(obj, *a, **k)
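With only this __get__ added, the question's __call__ still dispatches on args[pos] with pos defaulting to 0, and args[0] is now the MyClass instance, so every call would fall through to the default function. The Python 3 sketch below keeps the __get__ idea but dispatches on the first argument after the instance. The helper names _lookup and bound are hypothetical, and the asker's pos keyword is dropped as a simplifying assumption.

```python
class Overload:
    """Overloading decorator that also works on methods (Python 3 sketch)."""

    def __init__(self, default):
        self.default_function = default
        self.type_map = {}

    def _lookup(self, args):
        # Exact-type lookup on the first dispatchable argument; fall back to the default.
        if not args:
            return self.default_function
        return self.type_map.get(type(args[0]), self.default_function)

    def __call__(self, *args, **kwargs):
        # Plain-function use: dispatch on the first positional argument.
        return self._lookup(args)(*args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            return self  # accessed on the class, e.g. MyClass.print_first_item

        def bound(*args, **kwargs):
            # Instance access: dispatch on the first argument *after* the instance,
            # then prepend obj as self, the way a bound method would.
            return self._lookup(args)(obj, *args, **kwargs)

        return bound

    def overload(self, *types):
        def wrapper(f):
            for t in types:
                self.type_map[t] = f
            return self  # the decorated name keeps referring to this Overload
        return wrapper


class MyClass:
    def __init__(self):
        self.some_instance_var = 1

    @Overload
    def print_first_item(self, x):
        return x[0], self.some_instance_var

    @print_first_item.overload(str)
    def print_first_item(self, x):
        return x.split()[0], self.some_instance_var


m = MyClass()
print(m.print_first_item([7, 8]))     # (7, 1)
print(m.print_first_item("One Two"))  # ('One', 1)
```

Returning a closure from __get__ mirrors glglgl's lambda; the only change is where the dispatch key is taken from.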

Why?

Well, you use your Overload object as a replacement for a function. When it is looked up through an instance, you want it, like a function, to present itself as a bound method, i.e. with a different signature that no longer includes the leading self.

A short explanation of how method access works:

object.meth(1, 2)

gets translated (roughly; the real lookup also searches the base classes) to

type(object).__dict__['meth'].__get__(object, type(object))(1, 2)

A function's __get__() returns a method object which wraps the function and prepends the object to the argument list (where it becomes self):

realmethod = type(object).__dict__['meth'].__get__(object, type(object))
realmethod(1, 2)

where realmethod is a bound method object that knows both the function to call and the instance to pass as self, and calls the "real" function by transforming the call into

meth(object, 1, 2)

The new __get__ method imitates exactly this behaviour.
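The translation described above can be checked directly. In this Python 3 sketch (the classes C, NoGet and D and the method meth are illustrative names), calling the function's __get__ by hand yields a bound method that behaves like normal attribute access, while a callable object without __get__, like the question's original Overload, is returned unbound:

```python
class C:
    def meth(self, a, b):
        return (self, a, b)


obj = C()

# What obj.meth(1, 2) does under the hood: fetch the function from the
# class __dict__ and bind it to obj with the descriptor protocol.
realmethod = type(obj).__dict__['meth'].__get__(obj, type(obj))

print(realmethod.__self__ is obj)     # True: the method remembers its instance
print(realmethod.__func__ is C.meth)  # True: ...and the underlying function
print(realmethod(1, 2) == obj.meth(1, 2) == C.meth(obj, 1, 2))  # True


class NoGet:
    # Callable, but not a descriptor: no __get__ method.
    def __call__(self, *args):
        return args


class D:
    meth = NoGet()


print(D().meth(1, 2))  # (1, 2): no instance was prepended
```

The last line is the question's TypeError in miniature: without __get__, the instance never reaches the wrapped function.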

glglgl answered Sep 24 '22 00:09


As abarnert says, because you are using a class as your decorator, 'self' is an instance of Overload rather than of MyClass, as you hoped/expected.

I couldn't find a simple solution. The best I could come up with is to use a function instead of a class as the decorator, and give it a second argument whose default value is a dictionary. Default values are evaluated only once, when the def statement runs, and since a dictionary is mutable, every call sees (and updates) the same dictionary. I use it to store my 'class variables'. The rest follows a similar pattern to your solution.
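The "evaluated once" behaviour this answer relies on can be seen in isolation. A minimal sketch, where remember and seen are illustrative names only:

```python
def remember(item, seen={}):
    # The default dict is created once, when `def` runs, and reused on every call.
    seen[item] = seen.get(item, 0) + 1
    return seen


remember("a")
remember("b")
print(remember("a"))  # {'a': 2, 'b': 1}
```

The same object is reachable as remember.__defaults__[0], which is why state stored there survives between calls.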

Example:

import inspect

def overload(funcOrType, map={}, type=None):
    if not inspect.isclass(funcOrType):
        # We have a function so we are dealing with "@overload"
        if(type):
            map[type] = funcOrType
        else:
            map['default_function'] = funcOrType
    else:
        def overloadWithType(func):
            return overload(func, map, funcOrType)
        return  overloadWithType

    def doOverload(*args, **kwargs):
        for type in [t for t in map.keys() if t != 'default_function'] :
            if isinstance(args[1], type): # Note args[0] is 'self' i.e. MyClass instance.
                return map[type](*args, **kwargs)
        return map['default_function'](*args, **kwargs)

    return doOverload

Then:

from overload import *

class MyClass(object):
    def __init__(self):
        self.some_instance_var = 1

    @overload
    def print_first_item(self, x):
        return x[0], self.some_instance_var

    @overload(str)
    def print_first_item(self, x):
        return x.split()[0], self.some_instance_var


m = MyClass()
print(m.print_first_item(['a','b','c']))
print(m.print_first_item("One Two Three"))

Yields:

('a', 1)
('One', 1)
tam203 answered Sep 20 '22 00:09


For reference, here is the working implementation, thanks to the detailed explanation by glglgl:

argtype_tuple = lambda args: tuple(type(a) for a in args)

class Overload(object):
    def __init__(self, func):
        self.default = func
        self.map = {}

    def __call__(self, *args, **kwargs):
        # Dispatch key: the exact types of the positional args (self excluded).
        key_tuple = argtype_tuple(args)
        c_inst = kwargs.pop("c_inst", None)
        if c_inst is not None:
            # Called through __get__: put the instance back in front as self.
            args = (c_inst,) + args
        # Look up first, then call, so a KeyError raised inside the chosen
        # function is not mistaken for "no overload registered".
        func = self.map.get(key_tuple, self.default)
        return func(*args, **kwargs)

    def __get__(self, obj, cls):
        # "is not None", not truthiness: instances may define __len__/__bool__.
        if obj is not None:
            return lambda *args, **kwargs: self(c_inst=obj, *args, **kwargs)
        else:
            return self

    def overload(self, *types):
        def wrapper(f):
            for type_seq in types:
                # A bare type means a one-argument signature.
                if not isinstance(type_seq, tuple):
                    type_seq = (type_seq,)
                self.map[type_seq] = f
            return self
        return wrapper

#Some tests/usage examples
class A(object):
    @Overload
    def print_first(self, x):
        return x[0]

    @print_first.overload(str)
    def p_first(self, x):
        return x.split()[0]

    def __repr__(self):
        return "class A Instance"

a = A()
assert a.print_first([1,2,3]) == 1
assert a.print_first("one two three") == "one"

@Overload
def flatten(seq):
    return [seq]

@flatten.overload(list, tuple)
def flat(seq):
    return sum((flatten(item) for item in seq), [])

assert flatten([1,2,[3,4]]) == [1,2,3,4]
assert flat([1,2,[3,4]]) == [1,2,3,4]
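On Python 3.8 and later, the standard library covers the single-argument case of this pattern: functools.singledispatchmethod dispatches on the type of the first argument after self, and functools.singledispatch does the same for plain functions. Unlike the exact-type lookup above, both follow the class hierarchy, so subclasses of a registered type match too. A sketch of the two examples above rewritten that way (it does not replace the multi-argument tuple keys):

```python
from functools import singledispatch, singledispatchmethod


class A:
    @singledispatchmethod
    def print_first(self, x):
        # Default implementation, used when no registered type matches.
        return x[0]

    @print_first.register
    def _(self, x: str):
        # Chosen when x is a str (or a str subclass), via the annotation.
        return x.split()[0]


@singledispatch
def flatten(seq):
    # Default: wrap a non-sequence item in a list.
    return [seq]


@flatten.register(list)
@flatten.register(tuple)
def _(seq):
    return sum((flatten(item) for item in seq), [])


a = A()
print(a.print_first([1, 2, 3]))        # 1
print(a.print_first("one two three"))  # one
print(flatten([1, 2, [3, (4, 5)]]))    # [1, 2, 3, 4, 5]
```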
Joel Cornett answered Sep 22 '22 00:09