@classmethod not invoking my custom descriptor's __get__

I have a decorator called Special that turns a function into two versions of itself. One can be called directly and prefixes the result with 'regular '. The other can be called with .special and prefixes the result with 'special ':

class Special:
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return Special(self.func.__get__(instance, owner))

    def special(self, *args, **kwargs):
        return 'special ' + self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return 'regular ' + self.func(*args, **kwargs)
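For reference, here is a minimal sketch of how Special behaves on a plain module-level function outside any class. The function greet is a hypothetical example. No attribute lookup on a class happens here, so __get__ is never involved, and both call styles simply forward to the wrapped function:

```python
class Special:
    """Decorator from the question: callable directly or via .special."""

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner=None):
        # Accessed on the class: hand back the decorator itself.
        if instance is None:
            return self
        # Accessed on an instance: bind the wrapped function first.
        return Special(self.func.__get__(instance, owner))

    def special(self, *args, **kwargs):
        return 'special ' + self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return 'regular ' + self.func(*args, **kwargs)


@Special
def greet(name):  # hypothetical example function
    return 'hello ' + name


print(greet('world'))          # regular hello world
print(greet.special('world'))  # special hello world
```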

It works fine with regular methods and static methods, but .special does not work with class methods:

class Foo:
    @Special
    def bar(self):
        return 'bar'

    @staticmethod
    @Special
    def baz():
        return 'baz'

    @classmethod
    @Special
    def qux(cls):
        return 'qux'

assert Foo().bar() == 'regular bar'
assert Foo().bar.special() == 'special bar'

assert Foo.baz() == 'regular baz'
assert Foo.baz.special() == 'special baz'

assert Foo.qux() == 'regular qux'
assert Foo.qux.special() == 'special qux'  # TypeError: qux() missing 1 required positional argument: 'cls'

  • Foo().bar invokes __get__, which binds the underlying function and passes the bound method to a new instance of Special. That's why both Foo().bar() and Foo().bar.special() work.

  • Foo.baz simply returns the original Special instance, so both the regular and the special calls are straightforward.

  • Foo.qux binds without calling my __get__.

    • The new bound object knows to pass the class as the first argument when it is called directly, so Foo.qux() works.
    • Foo.qux.special is simply the .special of the underlying Special object, because classmethod does not know how to bind it. So Foo.qux.special() invokes the unbound function, hence the TypeError.
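The failing case can be reproduced without depending on classmethod's internals. The sketch below builds the bound method by hand with types.MethodType, which is what classmethod.__get__ returns when it does not delegate to the wrapped object's __get__. A bound method forwards unknown attribute lookups to its __func__. As a result, .special resolves to the Special object's own method, and that method's func is the plain, unbound qux. Foo and qux here are illustrative stand-ins:

```python
import types


class Special:
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return Special(self.func.__get__(instance, owner))

    def special(self, *args, **kwargs):
        return 'special ' + self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return 'regular ' + self.func(*args, **kwargs)


class Foo:
    pass


def qux(cls):
    return 'qux'


wrapped = Special(qux)
# Bind the Special object itself to the class, without calling its __get__.
bound = types.MethodType(wrapped, Foo)

print(bound())                    # regular qux: Foo is passed as cls
print(bound.__func__ is wrapped)  # True
# Attribute lookup falls through to __func__, i.e. to wrapped.special,
# which then calls qux() with no arguments at all.
try:
    bound.special()
except TypeError as exc:
    print('TypeError:', exc)
```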

Is there any way for Foo.qux.special to know it's being called from a classmethod? Or some other way around this problem?

asked Jan 27 '23 22:01 by Idan Arye


2 Answers

The problem is that classmethod.__get__ doesn't call the wrapped function's __get__, and that's basically the whole point of @classmethod. For details, you can look at the pure-Python equivalent of classmethod in the Descriptor HowTo Guide, or at the actual CPython C source in funcobject.c. Either way, you'll see that there's really no way around this. (Python 3.9 through 3.12 did make classmethod delegate to the wrapped object's __get__. That behavior was deprecated in 3.11 and removed in 3.13.)
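To see this concretely, here is a simplified pure-Python stand-in for classmethod that binds the wrapped object to the class directly. This is the behavior described above. SimpleClassMethod and CountingDescriptor are hypothetical names for illustration, not CPython's actual code. The counter shows that the wrapped descriptor's __get__ never runs:

```python
import types


class SimpleClassMethod:
    """Simplified classmethod stand-in (illustrative, not CPython's code)."""

    def __init__(self, f):
        self.f = f

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        # Bind the wrapped object to the class directly; its own
        # __get__ (if any) is deliberately not consulted.
        return types.MethodType(self.f, owner)


class CountingDescriptor:
    """Hypothetical callable descriptor that records __get__ calls."""

    def __init__(self, func):
        self.func = func
        self.get_calls = 0

    def __get__(self, instance, owner=None):
        self.get_calls += 1
        return self

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def qux(cls):
    return cls.__name__


counter = CountingDescriptor(qux)


class Foo:
    qux = SimpleClassMethod(counter)


print(Foo.qux())          # Foo
print(Foo().qux())        # Foo
print(counter.get_calls)  # 0: the wrapped __get__ was never invoked
```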


Of course, if you apply @Special on top of @classmethod instead of the other way around, everything will work fine when called on an instance:

class Foo:
    @Special
    @classmethod
    def spam(cls):
        return 'spam'

assert Foo().spam() == 'regular spam'
assert Foo().spam.special() == 'special spam'

… but now it won't work when called on the class:

assert Foo.spam() == 'regular spam'  # TypeError: 'classmethod' object is not callable
assert Foo.spam.special() == 'special spam'

… because you're trying to call a classmethod object, which isn't callable.
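Here is a quick check of that last point, using only built-ins; spam and Foo are illustrative names. A raw classmethod object has no __call__. Invoking its descriptor protocol by hand, however, yields an ordinary bound method with the class as __self__:

```python
def spam(cls):
    return 'spam from ' + cls.__name__


class Foo:
    pass


cm = classmethod(spam)
print(callable(cm))  # False: a classmethod object has no __call__

# Running the descriptor protocol manually produces something callable.
bound = cm.__get__(None, Foo)
print(callable(bound))        # True
print(bound.__self__ is Foo)  # True
print(bound())                # spam from Foo
```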


But this problem, unlike the previous one, is fixable. In fact, the only reason this fails is this part:

if instance is None:
    return self

When you try to bind a Special instance to a class, it just returns self instead of binding its wrapped object. This means it ends up as just a wrapper around a classmethod object rather than a wrapper around a bound class method, and of course you can't call a classmethod object.

But if you just leave that out, it'll let the underlying classmethod bind the same way a normal function does, which does exactly the right thing, and now everything works:

class Special:
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner=None):
        return Special(self.func.__get__(instance, owner))

    def special(self, *args, **kwargs):
        return 'special ' + self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return 'regular ' + self.func(*args, **kwargs)

class Foo:
    @Special
    def bar(self):
        return 'bar'

    @Special
    @staticmethod
    def baz():
        return 'baz'

    @Special
    @classmethod
    def qux(cls):
        return 'qux'

assert Foo().bar() == 'regular bar'
assert Foo().bar.special() == 'special bar'

assert Foo.baz() == 'regular baz'
assert Foo.baz.special() == 'special baz'
assert Foo().baz() == 'regular baz'
assert Foo().baz.special() == 'special baz'

assert Foo.qux() == 'regular qux'
assert Foo.qux.special() == 'special qux'
assert Foo().qux() == 'regular qux'
assert Foo().qux.special() == 'special qux'

Of course this will cause problems with wrapping unbound method objects in Python 2.7, but I think your design already breaks for normal methods in 2.7, and hopefully you only care about 3.x here anyway.
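The fix works because each kind of wrapped object already knows what to return when it is bound with instance=None. The sketch below calls __get__(None, Foo) on each kind directly; f and Foo are illustrative names. A plain function and a staticmethod both hand back the bare function. Forwarding unconditionally is therefore harmless for them, while a classmethod needs it to become callable:

```python
def f(*args):
    return args


class Foo:
    pass


plain = f
static = staticmethod(f)
klass = classmethod(f)

# A plain function bound to no instance is returned unchanged.
print(plain.__get__(None, Foo) is f)   # True
# A staticmethod unwraps to the underlying function.
print(static.__get__(None, Foo) is f)  # True
# A classmethod produces a bound method whose __self__ is the class.
bound = klass.__get__(None, Foo)
print(bound.__self__ is Foo)  # True
print(bound() == (Foo,))      # True: Foo is passed as the first argument
```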

answered Jan 30 '23 14:01 by abarnert


classmethod is a descriptor that returns a bound method. It doesn't invoke your __get__ method in this process because it can't do so without breaking some contracts of the descriptor protocol. (Namely, the fact that instance should be an instance, not a class.) So your __get__ method not being called is completely expected.
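To illustrate, a classmethod accessed through an instance still binds the class, never the instance; the instance is effectively discarded. Here is a small sketch in which Foo, Bar, and who are illustrative names:

```python
class Foo:
    @classmethod
    def who(cls):
        return cls


class Bar(Foo):
    pass


obj = Foo()
method = obj.who
print(method.__self__ is Foo)  # True: bound to the class, not to obj
print(method() is Foo)         # True
print(Bar().who() is Bar)      # True: the owner class is what gets bound
```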

So how do you make it work? Well, think about it: You want both some_instance.bar and SomeClass.bar to return a Special instance. To achieve that, you simply apply the @Special decorator last, so that it is the outermost decorator:

class Foo:
    @Special
    @staticmethod
    def baz():
        return 'baz'

    @Special
    @classmethod
    def qux(cls):
        return 'qux'

This gives you full control over whether, when, and how the decorated function's descriptor protocol is invoked. Now you just need to remove the if instance is None: special case in your __get__ method, because it prevents classmethods from working correctly. (The reason is that classmethod objects are not callable; you have to invoke the descriptor protocol to turn the classmethod object into a function that can be called.) In other words, the Special.__get__ method has to unconditionally call the decorated function's __get__ method, like this:

def __get__(self, instance=None, owner=None):
    return Special(self.func.__get__(instance, owner))

And now all your assertions will pass.
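Dropping the special case has one side effect, sketched below with the fixed Special definition. A regular method accessed on the class now comes back as a Special wrapping the plain function, because function.__get__(None, owner) returns the function itself. It can therefore still be called with an explicit instance, with or without .special. The subclass Sub is a hypothetical addition showing that the owner class is passed through:

```python
class Special:
    def __init__(self, func):
        self.func = func

    def __get__(self, instance=None, owner=None):
        # Always delegate to the wrapped object's own binding logic.
        return Special(self.func.__get__(instance, owner))

    def special(self, *args, **kwargs):
        return 'special ' + self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return 'regular ' + self.func(*args, **kwargs)


class Foo:
    @Special
    def bar(self):
        return 'bar from ' + type(self).__name__

    @Special
    @classmethod
    def qux(cls):
        return 'qux from ' + cls.__name__


class Sub(Foo):  # hypothetical subclass
    pass


print(Foo.bar(Foo()))          # regular bar from Foo
print(Foo.bar.special(Sub()))  # special bar from Sub
print(Sub.qux.special())       # special qux from Sub
```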

answered Jan 30 '23 13:01 by Aran-Fey