I'm trying to write a decorator that provides method overloading functionality to Python, similar to the one mentioned in PEP 3124.
The decorator I wrote works great for regular functions, but I can't get it to work for methods in a class.
Here is the decorator:
class Overload(object):
    """Decorator class that dispatches a call on the type of one positional argument.

    This is the question's original (Python 2) version. It works for plain
    functions, but not as a method: see the NOTE in __call__.
    """
    def __init__(self, default):
        # default: implementation used when no registered type matches.
        self.default_function = default
        # Maps a type -> implementation registered through .overload().
        self.type_map = {}
        # Index of the positional argument to dispatch on; None means "read
        # the index from a 'pos' keyword argument, defaulting to 0".
        self.pos = None
    def __call__(self, *args, **kwargs):
        # NOTE(review): Overload defines no __get__, so when it is stored on a
        # class and accessed through an instance, the instance is NOT passed
        # in: m.print_first_item(1) arrives here with args == (1,), and the
        # default function is then called without its 'self' argument.
        # NOTE(review): if the instance were passed as args[0], pos=0 would
        # dispatch on the instance's type rather than on the first real argument.
        print self  # debug output (Python 2 print statement)
        try:
            if self.pos is None:
                # NOTE(review): 'pos' is read but not removed from kwargs, so
                # it is also forwarded to the selected implementation.
                pos = kwargs.get("pos", 0)
            else:
                pos = self.pos
            print args, kwargs  # debug output
            # NOTE(review): the call itself is inside the try, so a KeyError
            # raised by the implementation also falls back to the default.
            return self.type_map[type(args[pos])](*args, **kwargs)
        except KeyError:
            # No implementation registered for this exact type.
            return self.default_function(*args, **kwargs)
        except IndexError:
            # Fewer positional arguments than the dispatch index.
            return self.default_function(*args, **kwargs)
    def overload(self, *d_type):
        """Return a decorator registering a function for each type in d_type."""
        def wrapper(f):
            for dt in d_type:
                self.type_map[dt] = f
            # Returning self rebinds the decorated name to this Overload.
            return self
        return wrapper
When I attempt to implement it like this:
class MyClass(object):
    """Example class from the question that reproduces the TypeError."""
    def __init__(self):
        self.some_instance_var = 1
    # Default implementation, used when no registered type matches.
    @Overload
    def print_first_item(self, x):
        return x[0], self.some_instance_var
    # Implementation registered for str. Overload.overload returns the
    # Overload object, so this rebinds print_first_item to that same object.
    @print_first_item.overload(str)
    def print_first_item(self, x):
        return x.split()[0], self.some_instance_var
I get a TypeError when I run it:
>>> m = MyClass()
>>> m.print_first_item(1)
<__main__.Overload object at 0x2> (1,) {}
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "overload.py", line 17, in __call__
return self.default_function(*args, **kwargs)
TypeError: print_first_item() takes exactly 2 arguments (1 given)
>>>
My question is: how can I access the instance of MyClass (i.e. self) from within the decorated method?
Essentially, your Overload class needs a __get__ method:
def __get__(self, obj, cls):
    """Bind this Overload to obj when it is accessed as an instance attribute.

    Called on access of MyClass.print_first_item. Accessed through the class
    (obj is None), it returns itself unchanged. Accessed through an instance,
    it returns a bound method whose calls become self(obj, *args, **kwargs),
    mirroring what a plain function's __get__ does.
    """
    if obj is None:
        # A real function would do some checks here, but we leave that out.
        return self
    import types  # local import keeps this snippet self-contained
    # types.MethodType prepends obj to the arguments of every call, exactly
    # like the bound method produced by a function's __get__.
    return types.MethodType(self, obj)
Why?
Well, you use your Overload object as a kind of function replacement. You want it, like a function, to present itself in a method context with a different signature.
A short explanation of how method access works:
object.meth(1, 2)  # "object" is a placeholder for an instance of a class that defines meth()
gets translated (roughly) to
# "object" stands for an instance. Methods are stored in the class's __dict__
# (inherited ones are found by searching type(object).__mro__), not in the instance's.
type(object).__dict__['meth'].__get__(object, type(object))(1, 2)
A function's __get__() returns a method object that wraps the function by prepending the object to the parameter list (where it becomes self):
# Look the function up on the class, then bind it to the instance.
realmethod = type(object).__dict__['meth'].__get__(object, type(object))
realmethod(1, 2)
where realmethod is a method object that knows the function to be called and the self to be passed to it, and calls the "real" function appropriately by transforming the call into
meth(object, 1, 2)  # the plain function, with the instance passed as self
.
We imitate this behaviour in the new __get__ method.
As abarnert says, because you are using a class as your decorator, 'self' is an instance of Overload rather than of MyClass, as you hoped/expected.
I couldn't find a simple solution. The best thing I could come up with is to not use a class as a decorator and instead use a function with a second argument that defaults to a dictionary. Since this is a mutable type, it will be the same dictionary every time the function is called. I use this to store my 'class variables'. The rest follows a similar pattern to your solution.
Example:
import inspect
def overload(funcOrType, map={}, type=None):
    """Register an overload, or build a decorator that registers one.

    @overload applied to a function registers it as the default
    implementation and returns a dispatcher. @overload(SomeType) returns a
    decorator that registers the decorated function for calls whose second
    positional argument (args[1]; args[0] is self) is an instance of
    SomeType. Both forms rebind the decorated name to a dispatcher that uses
    the same per-function dispatch table.

    map is deliberately a mutable default that is shared across calls. It
    holds one dispatch table per overloaded function, keyed by the function's
    module and qualified name. This keeps overloads of different functions,
    or of same-named methods in different classes, from overwriting each
    other. type is used internally by the @overload(SomeType) form.
    """
    if inspect.isclass(funcOrType):
        # "@overload(SomeType)": wait for the function being decorated.
        def overloadWithType(func):
            return overload(func, map, funcOrType)
        return overloadWithType

    func = funcOrType
    # The default and its typed variants are defined under the same name in
    # the same scope, so they share a qualified name and thus one table.
    key = (getattr(func, "__module__", None),
           getattr(func, "__qualname__", getattr(func, "__name__", None)))
    table = map.setdefault(key, {})
    if type is not None:
        table[type] = func
    else:
        table['default_function'] = func

    def doOverload(*args, **kwargs):
        # args[0] is the instance (self); dispatch on the first real argument.
        # The first registered type that matches via isinstance wins.
        for t, impl in table.items():
            if t != 'default_function' and isinstance(args[1], t):
                return impl(*args, **kwargs)
        return table['default_function'](*args, **kwargs)
    return doOverload
Then:
# Import only the decorator; a wildcard import would pull in the whole module namespace.
from overload import overload


class MyClass(object):
    """Example class whose print_first_item is overloaded on the type of x."""

    def __init__(self):
        self.some_instance_var = 1

    # Default implementation, used when no registered type matches x.
    @overload
    def print_first_item(self, x):
        return x[0], self.some_instance_var

    # Implementation used when x is a str.
    @overload(str)
    def print_first_item(self, x):
        return x.split()[0], self.some_instance_var


m = MyClass()
print(m.print_first_item(['a', 'b', 'c']))
print(m.print_first_item("One Two Three"))
Yields:
('a', 1)
('One', 1)
For reference, here is the working implementation, thanks to the detailed explanation by glglgl:
def argtype_tuple(args):
    """Return the tuple of the exact types of args, used as a dispatch key."""
    return tuple(type(a) for a in args)


class Overload(object):
    """Decorator that dispatches on the exact types of all positional arguments.

    Decorate a default implementation with @Overload and register
    alternatives with @<name>.overload(type_or_tuple, ...). When used as a
    method, the instance is left out of the dispatch key and is passed to the
    selected implementation as its first argument.
    """

    def __init__(self, func):
        self.default = func  # called when no registered signature matches
        self.map = {}        # maps a tuple of argument types -> implementation

    def __call__(self, *args, **kwargs):
        """Dispatch on the types of args and call the matching implementation.

        If a c_inst keyword is given and is not None, it is removed from
        kwargs and prepended to the positional arguments after the dispatch
        key is computed. This is kept for backward compatibility; __get__ no
        longer relies on it.
        """
        c_inst = kwargs.pop("c_inst", None)
        return self._dispatch(c_inst, args, kwargs)

    def _dispatch(self, inst, args, kwargs):
        """Select the implementation for args and call it, prepending inst if given."""
        # Look up outside of any try block: a KeyError raised *inside* the
        # chosen implementation must propagate, not fall back to the default.
        func = self.map.get(argtype_tuple(args), self.default)
        if inst is not None:
            args = (inst,) + args
        return func(*args, **kwargs)

    def __get__(self, obj, cls):
        # Accessed on the class: behave like an unbound function.
        if obj is None:
            return self

        # Accessed on an instance: bind it, but keep it out of the dispatch key.
        # Compare with None rather than using truthiness, so instances that
        # are falsy (e.g. define __len__ returning 0) are still bound.
        def bound(*args, **kwargs):
            return self._dispatch(obj, args, kwargs)
        return bound

    def overload(self, *types):
        """Return a decorator that registers a function for each given signature.

        Each element of types is either a single type (a one-argument
        signature) or a tuple of types (a multi-argument signature). The
        decorator returns this Overload, so the decorated name is rebound to
        the dispatcher.
        """
        def wrapper(f):
            for type_seq in types:
                key = tuple(type_seq) if isinstance(type_seq, tuple) else (type_seq,)
                self.map[key] = f
            return self
        return wrapper
# Some tests/usage examples
class A(object):
    @Overload
    def print_first(self, x):
        return x[0]

    # Overload.overload returns the Overload itself, so p_first is just
    # another name for the same dispatcher as print_first.
    @print_first.overload(str)
    def p_first(self, x):
        return x.split()[0]

    def __repr__(self):
        return "class A Instance"


a = A()
assert a.print_first([1, 2, 3]) == 1
assert a.print_first("one two three") == "one"


@Overload
def flatten(seq):
    return [seq]


@flatten.overload(list, tuple)
def flat(seq):
    # Build the result in one pass; sum(..., []) re-copies the accumulated
    # list on every addition, which is quadratic in the output length.
    return [leaf for item in seq for leaf in flatten(item)]


assert flatten([1, 2, [3, 4]]) == [1, 2, 3, 4]
assert flat([1, 2, [3, 4]]) == [1, 2, 3, 4]
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With