I have a decorator @pure that registers a function as pure, for example:
@pure
def rectangle_area(a, b):
    return a * b

@pure
def triangle_area(a, b, c):
    return ((a+(b+c))*(c-(a-b))*(c+(a-b))*(a+(b-c)))**0.5/4
Next, I want to identify a newly defined pure function:
def house_area(a, b, c):
    return rectangle_area(a, b) + triangle_area(a, b, c)
Obviously house_area is pure, since it only calls pure functions.
How can I discover all pure functions automatically (perhaps by using ast)?
Assuming all operators are pure, you essentially only need to check the function calls. This can indeed be done with the ast module.
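As a minimal sketch of the idea, the snippet below parses a piece of source text and collects the plain names that are called in it. The helper name called_names and the SOURCE string are made up for illustration, and only direct calls such as f(x) are handled here.

```python
import ast

# Source text to inspect; it is parsed, never executed.
SOURCE = """
def house_area(a, b, c):
    return rectangle_area(a, b) + triangle_area(a, b, c)
"""

def called_names(source):
    """Return the plain names called anywhere in `source`, sorted."""
    tree = ast.parse(source)
    names = set()
    for node in ast.walk(tree):
        # Only direct calls such as f(x) are collected. For a call such as
        # math.sin(x), node.func is an ast.Attribute and is skipped here.
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            names.add(node.func.id)
    return sorted(names)

print(called_names(SOURCE))  # ['rectangle_area', 'triangle_area']
```

Every name found this way is a candidate that has to be pure for the calling function to be pure.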
First I defined the pure decorator as:
def pure(f):
    f.pure = True
    return f
Adding an attribute that marks the function as pure allows the check to stop early, or "forces" a function to identify as pure. This is useful if you need a function like math.sin to identify as pure. Since you can't add attributes to built-in functions, wrap it instead:
import math

@pure
def sin(x):
    return math.sin(x)
All in all, use the ast module to visit all the nodes. Then, for each Call node, check whether the function being called is pure.
import ast

class PureVisitor(ast.NodeVisitor):
    def __init__(self, visited):
        super().__init__()
        self.pure = True
        self.visited = visited  # names that have already been checked

    def visit_Name(self, node):
        return node.id

    def visit_Attribute(self, node):
        # Rebuild a dotted name such as "math.sin" from nested Attribute nodes.
        name = [node.attr]
        child = node.value
        while child is not None:
            if isinstance(child, ast.Attribute):
                name.append(child.attr)
                child = child.value
            else:
                name.append(child.id)
                break
        name = ".".join(reversed(name))
        return name

    def visit_Call(self, node):
        if not self.pure:
            return
        name = self.visit(node.func)
        if name not in self.visited:
            self.visited.append(name)
            try:
                callee = eval(name)
                if not is_pure(callee, self.visited):
                    self.pure = False
            except NameError:
                self.pure = False
        # Also check calls nested in the arguments, e.g. f(g(x)).
        self.generic_visit(node)
Then check whether the function has the pure attribute. If not, get its source code and check whether all the function calls can be classified as pure.
import inspect, textwrap

def is_pure(f, _visited=None):
    # Functions marked with @pure are accepted without further analysis.
    try:
        return f.pure
    except AttributeError:
        pass
    try:
        code = inspect.getsource(f.__code__)
    except (AttributeError, OSError):
        # No __code__ (e.g. a built-in) or no retrievable source.
        return False
    code = textwrap.dedent(code)
    node = compile(code, "<unknown>", "exec", ast.PyCF_ONLY_AST)
    if _visited is None:
        _visited = []
    visitor = PureVisitor(_visited)
    visitor.visit(node)
    return visitor.pure
Note that print(is_pure(lambda x: math.sin(x))) doesn't work, since inspect.getsource(f.__code__) returns code on a line-by-line basis. The source returned by getsource would therefore include the print and is_pure calls, thus yielding False, unless those functions are overridden.
To verify that it works, test it by doing:
print(is_pure(house_area))  # Prints: True
To check all the functions in the current module:
import sys, types

for k in dir(sys.modules[__name__]):
    v = globals()[k]
    if isinstance(v, types.FunctionType):
        print(k, is_pure(v))
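A different route to the same listing, sketched here as an alternative rather than as the approach above, is to analyze a module's source text directly. Functions decorated with a bare @pure seed the result, and other functions are added until nothing changes. The names find_pure_functions, MODULE_SOURCE and noisy_area are made up for illustration. This sketch only understands direct calls by plain name and does not resolve mutual recursion between undecorated functions.

```python
import ast

# Example module text; it is parsed, never executed, so `pure` need not exist.
MODULE_SOURCE = '''
@pure
def rectangle_area(a, b):
    return a * b

@pure
def triangle_area(a, b, c):
    return ((a+(b+c))*(c-(a-b))*(c+(a-b))*(a+(b-c)))**0.5/4

def house_area(a, b, c):
    return rectangle_area(a, b) + triangle_area(a, b, c)

def noisy_area(a, b):
    print(a, b)
    return rectangle_area(a, b)
'''

def find_pure_functions(source):
    """Return the set of top-level function names inferred to be pure."""
    tree = ast.parse(source)
    functions = {n.name: n for n in tree.body if isinstance(n, ast.FunctionDef)}
    # Seed: functions explicitly decorated with a bare @pure.
    pure = {
        name for name, fn in functions.items()
        if any(isinstance(d, ast.Name) and d.id == "pure" for d in fn.decorator_list)
    }
    changed = True
    while changed:  # repeat until no new function can be added
        changed = False
        for name, fn in functions.items():
            if name in pure:
                continue
            calls = [n for n in ast.walk(fn) if isinstance(n, ast.Call)]
            # Pure if every call targets a known pure name or the function itself.
            if all(
                isinstance(c.func, ast.Name) and (c.func.id in pure or c.func.id == name)
                for c in calls
            ):
                pure.add(name)
                changed = True
    return pure

print(sorted(find_pure_functions(MODULE_SOURCE)))
# ['house_area', 'rectangle_area', 'triangle_area']
```

Because nothing is imported or evaluated, this variant avoids eval entirely, at the cost of only seeing functions defined in the text it is given.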
The visited list keeps track of which functions have already been checked. This helps circumvent problems related to recursion: since the code is analyzed rather than executed, the evaluation would otherwise keep revisiting a recursive function such as factorial:
@pure
def factorial(n):
    return 1 if n == 1 else n * factorial(n - 1)
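To illustrate the role of the visited list in isolation, here is a small sketch that runs over a hand-written call graph instead of real functions. The CALLS dictionary and the helper is_pure_name are hypothetical, and names missing from the graph are simply assumed to be impure.

```python
# Hypothetical call graph: function name -> names it calls.
CALLS = {
    "factorial": ["factorial"],          # direct recursion
    "is_even": ["is_odd"],               # mutual recursion
    "is_odd": ["is_even"],
    "log_area": ["print", "factorial"],  # calls something unknown
}

def is_pure_name(name, visited=None):
    """Check purity by name, skipping callees that were already visited."""
    if visited is None:
        visited = []
    if name not in CALLS:
        return False  # unknown callee (e.g. print): assumed impure
    for callee in CALLS[name]:
        if callee in visited:
            continue  # already being checked; recursing again would never end
        visited.append(callee)
        if not is_pure_name(callee, visited):
            return False
    return True

print(is_pure_name("factorial"))  # True
print(is_pure_name("is_even"))    # True
print(is_pure_name("log_area"))   # False
```

Without the visited check, the first two calls would recurse until Python raises RecursionError.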
Note that you might need to revise the following code and choose another way to obtain a function from its name, since eval looks the name up in the namespace where the visitor is defined:
try:
    callee = eval(name)
    if not is_pure(callee, self.visited):
        self.pure = False
except NameError:
    self.pure = False
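One possible replacement for eval, offered only as a sketch, is to resolve the dotted name step by step against a namespace dictionary (for instance the __globals__ of the function being analyzed) and fall back to the builtins module. The helper resolve_name is hypothetical. It raises NameError on failure so that the existing except NameError branch would still apply.

```python
import builtins
import math

def resolve_name(dotted_name, namespace):
    """Look up a dotted name such as 'math.sin' without calling eval.

    `namespace` is a dict, e.g. some_function.__globals__.
    Raises NameError if any part of the name cannot be resolved.
    """
    first, *rest = dotted_name.split(".")
    if first in namespace:
        obj = namespace[first]
    elif hasattr(builtins, first):
        obj = getattr(builtins, first)
    else:
        raise NameError(dotted_name)
    for attr in rest:
        try:
            obj = getattr(obj, attr)
        except AttributeError:
            raise NameError(dotted_name) from None
    return obj

namespace = {"math": math}
print(resolve_name("math.sin", namespace) is math.sin)  # True
print(resolve_name("len", namespace) is len)            # True
```

Unlike eval, this never executes arbitrary expressions; it only performs dictionary and attribute lookups.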