
Identifying pure functions in python

I have a decorator @pure that registers a function as pure, for example:

@pure
def rectangle_area(a,b):
    return a*b


@pure
def triangle_area(a,b,c):
    return ((a+(b+c))*(c-(a-b))*(c+(a-b))*(a+(b-c)))**0.5/4
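The question doesn't show how @pure registers a function. One possible shape, assuming a module-level set (the name PURE_REGISTRY is hypothetical), is sketched below. A side benefit of a registry is that built-ins such as math.sin can be registered directly, since they are hashable even though they can't carry attributes.

```python
import math

# Hypothetical registry: the set of function objects declared pure.
PURE_REGISTRY = set()

def pure(f):
    """Register f as pure and return it unchanged."""
    PURE_REGISTRY.add(f)
    return f

@pure
def rectangle_area(a, b):
    return a * b

@pure
def triangle_area(a, b, c):
    # Heron's formula, written as in the question.
    return ((a + (b + c)) * (c - (a - b)) * (c + (a - b)) * (a + (b - c))) ** 0.5 / 4

# Built-ins can be registered directly, without wrapping them.
PURE_REGISTRY.add(math.sin)

print(rectangle_area in PURE_REGISTRY)  # True
print(triangle_area(3, 4, 5))           # 6.0
```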

Next, I want to identify a newly defined pure function:

def house_area(a,b,c):
    return rectangle_area(a,b) + triangle_area(a,b,c)

Obviously house_area is pure, since it only calls pure functions.

How can I discover all pure functions automatically (perhaps by using ast)?

Uri Goren asked Aug 09 '17 13:08


1 Answer

Assuming all operators are pure, you essentially only need to check the function calls. This can indeed be done with the ast module.
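To see why only calls matter, here is a small sketch that parses house_area from a source string (so it doesn't depend on inspect) and lists what ast reports. Operators such as + become BinOp nodes, while function calls become Call nodes; the variable names are illustrative only.

```python
import ast

source = """
def house_area(a, b, c):
    return rectangle_area(a, b) + triangle_area(a, b, c)
"""

tree = ast.parse(source)

# Names of plain (non-dotted) functions called anywhere in the tree.
called = sorted(
    node.func.id
    for node in ast.walk(tree)
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
)

# Binary operators show up as BinOp nodes, not as calls.
binops = [type(node.op).__name__ for node in ast.walk(tree) if isinstance(node, ast.BinOp)]

print(called)  # ['rectangle_area', 'triangle_area']
print(binops)  # ['Add']
```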

First I defined the pure decorator as:

def pure(f):
    f.pure = True
    return f

Adding an attribute that marks the function as pure lets is_pure return early, or "force" a function to be treated as pure. This is useful if you need a function like math.sin to count as pure. Since you can't add attributes to built-in functions, wrap it in a pure function instead:

import math

@pure
def sin(x):
    return math.sin(x)
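The reason for the wrapper can be checked directly. This sketch repeats the answer's attribute-based pure decorator and shows that tagging math.sin itself raises AttributeError (built-in functions have no __dict__), while the wrapper accepts the attribute.

```python
import math

def pure(f):
    f.pure = True
    return f

# Built-in functions don't accept new attributes.
try:
    math.sin.pure = True
    builtin_tagged = True
except AttributeError:
    builtin_tagged = False

@pure
def sin(x):
    return math.sin(x)

print(builtin_tagged)      # False
print(sin.pure, sin(0.0))  # True 0.0
```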

Next, use the ast module to visit all the nodes, and for each Call node check whether the function being called is pure.

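import ast

class PureVisitor(ast.NodeVisitor):
    def __init__(self, visited):
        super().__init__()
        self.pure = True
        # Names of callees already checked; shared across recursive is_pure calls.
        self.visited = visited

    def visit_Name(self, node):
        return node.id

    def visit_Attribute(self, node):
        # Rebuild a dotted name such as "math.sin" from nested Attribute nodes.
        name = [node.attr]
        child = node.value
        while isinstance(child, ast.Attribute):
            name.append(child.attr)
            child = child.value
        if not isinstance(child, ast.Name):
            # e.g. f(x).attr: no static name, but still check any calls inside.
            self.visit(child)
            return None
        name.append(child.id)
        return ".".join(reversed(name))

    def visit_Call(self, node):
        if not self.pure:
            return
        name = self.visit(node.func)
        if name is None:
            # The callee can't be named statically, so treat it as impure.
            self.pure = False
        elif name not in self.visited:
            self.visited.append(name)
            try:
                callee = eval(name)
                if not is_pure(callee, self.visited):
                    self.pure = False
            except NameError:
                self.pure = False
        # Also check calls nested in the arguments, e.g. f(g(x)).
        self.generic_visit(node)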
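On Python 3.9+, ast.unparse offers a shorter way to get the dotted name of a callee than walking Attribute nodes by hand. This is a swapped-in alternative, not what the visitor above uses, and it only parses source (nothing is imported or executed):

```python
import ast  # ast.unparse requires Python 3.9+

tree = ast.parse("os.path.join(a, b) + math.sin(x) * len(items)")

# For plain names and attribute chains, unparsing the func node
# yields exactly the dotted name.
names = [ast.unparse(node.func) for node in ast.walk(tree) if isinstance(node, ast.Call)]
print(sorted(names))  # ['len', 'math.sin', 'os.path.join']
```

One caveat: for a call like f(x).g(), unparse returns "f(x).g", and passing that to eval would actually run f. So a visitor built on unparse should still only accept callees that are Name or Attribute chains.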

Then define is_pure, which first checks whether the function has the pure attribute. If it doesn't, it gets the function's source code and checks whether all the function calls in it can be classified as pure.

import ast, inspect, textwrap

def is_pure(f, _visited=None):
    # Functions marked with @pure are trusted without inspecting them.
    try:
        return f.pure
    except AttributeError:
        pass

    # Built-ins have no __code__, and getsource raises OSError when the
    # source isn't available (e.g. functions typed into the REPL).
    try:
        code = inspect.getsource(f.__code__)
    except (AttributeError, OSError):
        return False

    node = ast.parse(textwrap.dedent(code))

    if _visited is None:
        _visited = []

    visitor = PureVisitor(_visited)
    visitor.visit(node)
    return visitor.pure

Note that print(is_pure(lambda x: math.sin(x))) doesn't work, because inspect.getsource(f.__code__) returns whole source lines. The source returned for the lambda therefore also includes the surrounding print and is_pure calls, which aren't marked pure, so the result is False (unless those functions are overridden with pure versions).


To verify that it works, test it by doing:

print(is_pure(house_area)) # Prints: True

To list all the functions in the current module and whether each one is pure:

import sys, types

for k in dir(sys.modules[__name__]):
    v = globals()[k]
    if isinstance(v, types.FunctionType):
        print(k, is_pure(v))
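The loop above also reports functions that were imported into the module (for example via from textwrap import dedent). A sketch that keeps only functions defined in the module itself, assuming the code runs as a normal script or module so that __name__ is set, filters on each function's __module__:

```python
import inspect
from textwrap import dedent  # imported Python function; should be skipped

def rectangle_area(a, b):
    return a * b

def helper(x):
    return x + 1

# Snapshot globals() first so the dict can't change size while iterating,
# then keep only functions whose __module__ is this module.
local_functions = sorted(
    name
    for name, obj in list(globals().items())
    if inspect.isfunction(obj) and obj.__module__ == __name__
)
print(local_functions)
```

Run as a script, this lists helper and rectangle_area but not dedent; each name could then be passed to is_pure as in the loop above.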

The visited list keeps track of which functions have already been checked. This prevents infinite recursion: since the code is never executed, checking a recursive function such as factorial without the @pure decorator would otherwise visit factorial over and over.

@pure
def factorial(n):
    return 1 if n == 1 else n * factorial(n - 1)
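The role of the visited list can be shown without ast at all. The sketch below runs the same rule on a hand-written call graph: the names is_pure_in_graph, calls and marked are hypothetical, and, like the visitor above, a callee that is already being checked is skipped rather than re-entered, which is what stops direct and mutual recursion from looping forever.

```python
def is_pure_in_graph(name, calls, marked, visited=None):
    # `calls` maps each function name to the names it calls;
    # `marked` holds names declared pure up front (like @pure).
    if name in marked:
        return True
    if name not in calls:
        return False  # unknown callee, e.g. print: assume impure
    if visited is None:
        visited = set()
    visited.add(name)
    for callee in calls[name]:
        if callee in visited:
            continue  # already being checked; skipping breaks the cycle
        if not is_pure_in_graph(callee, calls, marked, visited):
            return False
    return True

calls = {
    "house_area": ["rectangle_area", "triangle_area"],
    "factorial": ["factorial"],          # direct recursion
    "is_even": ["is_odd"],               # mutual recursion
    "is_odd": ["is_even"],
    "log_area": ["print", "house_area"],
}
marked = {"rectangle_area", "triangle_area"}

for fn in ["house_area", "factorial", "is_even", "log_area"]:
    print(fn, is_pure_in_graph(fn, calls, marked))
# house_area True
# factorial True
# is_even True
# log_area False
```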

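Note that you might need to revise the following code and choose another way to obtain a function from its name. Called without explicit namespaces, eval(name) looks the name up in the scope of visit_Call (its local variables and the globals of the module that defines PureVisitor), not in the module of the function being checked: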

try:
    callee = eval(name)
    if not is_pure(callee, self.visited):
        self.pure = False
except NameError:
    self.pure = False
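One possible replacement for eval, assuming the visitor is given the checked function's own namespace (its __globals__), is a plain lookup that falls back to builtins and then follows attributes with getattr. The resolve helper below is hypothetical; note that getattr can still run property code on unusual objects, and an object that really is None can't be told apart from a failed lookup here.

```python
import builtins
import math

def resolve(name, namespace):
    """Look up a dotted name like 'math.sin' in a namespace dict without eval.

    Falls back to builtins for the first part; returns None if any part is missing.
    """
    first, *rest = name.split(".")
    if first in namespace:
        obj = namespace[first]
    elif hasattr(builtins, first):
        obj = getattr(builtins, first)
    else:
        return None
    for attr in rest:
        try:
            obj = getattr(obj, attr)
        except AttributeError:
            return None
    return obj

def sqrt_area(a, b):
    return math.sqrt(a * b)

ns = sqrt_area.__globals__  # the namespace sqrt_area's names resolve in
print(resolve("math.sqrt", ns) is math.sqrt)  # True
print(resolve("len", ns) is len)              # True
print(resolve("no_such_function", ns))        # None
```

To use it, the visitor would need the checked function's __globals__ (for example passed in from is_pure) and would call resolve(name, namespace) where it currently calls eval(name).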
vallentin answered Oct 11 '22 12:10