Automatically refactor python lambdas to named functions

Tags:

python

I am working on the purescript-python project and there are several core libraries that make extensive use of lambdas. Because of the way the code is compiled, the location of the lambdas winds up getting obscured, which results in bizarre console messages whenever an error occurs.

I would like to refactor these libraries to use lambdas as little as possible. So, for example, if there is something like:

def foo(a):
  return lambda b: lambda c: lambda d: lambda e: a + b + c + d + e
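A quick way to see the problem is with plain CPython. This is a sketch, and it assumes that the compiled purescript-python output behaves like the hand-written code here. Every function created by a lambda expression is named <lambda>. Because all four lambdas sit on one line, neither the frame name nor the line number in a traceback tells you which one failed.

```python
import traceback

# Plain CPython; compiled purescript-python output is assumed to behave similarly.
def foo(a):
    return lambda b: lambda c: lambda d: lambda e: a + b + c + d + e

inner = foo(1)(2)(3)(4)
print(inner(5))            # 15
print(inner.__name__)      # <lambda>
print(inner.__qualname__)  # foo.<locals>.<lambda>.<locals>.<lambda>.<locals>.<lambda>.<locals>.<lambda>

# Trigger an error inside the innermost lambda (1 + 2 + 3 + 4 + "x").
try:
    inner("x")
except TypeError as exc:
    failing_frame = traceback.extract_tb(exc.__traceback__)[-1]
    print(failing_frame.name)  # <lambda>
```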

it would be nice to generate

def foo(a):
  def _foo_internal_anon_1(b):
    def _foo_internal_anon_2(c):
      def _foo_internal_anon_3(d):
        def _foo_internal_anon_4(e):
          return a + b + c + d + e
        return _foo_internal_anon_4
      return _foo_internal_anon_3
    return _foo_internal_anon_2
  return _foo_internal_anon_1

Is there a way to do this automatically, e.g. with Pylint or a VS Code or PyCharm plugin, or does it have to be done by hand?

asked Jul 01 '20 07:07 by Mike





1 Answer

You can use a custom ast.NodeTransformer to turn a Lambda returned by the final return statement of a function into a full nested function definition. The transformed AST can then be unparsed with the help of the unparse.py tool from the CPython repo (starting with Python 3.9, you can also use ast.unparse). This lets you transform a whole script, not just single functions.
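The transformer looks for one specific shape: a FunctionDef whose last statement is a Return whose value is a Lambda. A quick way to confirm that shape, using only the standard ast module, is the following sketch. The lambda's body is itself another Lambda, which is why the conversion has to recurse.

```python
import ast

source = "def foo(a):\n    return lambda b: lambda c: a + b + c\n"
func = ast.parse(source).body[0]
last_stmt = func.body[-1]

print(type(func).__name__)                             # FunctionDef
print(type(last_stmt).__name__)                        # Return
print(type(last_stmt.value).__name__)                  # Lambda
print([arg.arg for arg in last_stmt.value.args.args])  # ['b']
# The body of the outer lambda is the next lambda in the chain.
print(type(last_stmt.value.body).__name__)             # Lambda
```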

This is the node transformer:

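import ast
from contextlib import contextmanager


@contextmanager
def resetattr(obj, name, value):
    """Temporarily set obj.name to value, restoring the old value afterwards."""
    old_value = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, old_value)


class ConvertLambda(ast.NodeTransformer):
    """Rewrite a trailing `return lambda ...` into a named nested def."""

    def __init__(self):
        super().__init__()
        self.base_name = None  # name of the outermost function in the current chain
        self.n = 0             # nesting depth, used to number the generated functions

    def visit_FunctionDef(self, node):
        # Also convert functions nested elsewhere in this function's body.
        self.generic_visit(node)
        if isinstance(node.body[-1], ast.Return) and isinstance(node.body[-1].value, ast.Lambda):
            lambda_node = node.body[-1].value
            with resetattr(self, 'base_name', self.base_name or node.name):
                with resetattr(self, 'n', self.n + 1):
                    func_name = f'_{self.base_name}_internal_anon_{self.n}'
                    func_def = ast.FunctionDef(
                        name=func_name,
                        args=lambda_node.args,
                        body=[ast.Return(value=lambda_node.body)],
                        decorator_list=[],
                        returns=None,
                    )
                    # Recurse so that a lambda returned by this lambda is converted too.
                    self.visit(func_def)
            # Define the new function just before the return and return it by name.
            node.body.insert(-1, func_def)
            node.body[-1].value = ast.Name(id=func_name, ctx=ast.Load())
        return node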

It can be used together with the Unparser class as follows (alternatively, use ast.unparse on Python 3.9+):

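import ast

from unparse import Unparser  # unparse.py from the CPython repo (needed before Python 3.9)


def convert_func_def(text):
    tree = ast.parse(text)
    # The new FunctionDef/Name nodes have no line numbers; fill them in from their parents.
    tree = ast.fix_missing_locations(ConvertLambda().visit(tree))
    Unparser(tree)  # prints to sys.stdout; on Python 3.9+ print(ast.unparse(tree)) also works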

By default, this prints the result to sys.stdout, but Unparser can be configured to write to any file-like object: Unparser(tree, file=...).
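On Python 3.9+, ast.unparse returns the source as a string, so writing it to a file-like object is a plain write call. The following sketch also shows why convert_func_def calls ast.fix_missing_locations. Hand-built nodes, like the ones ConvertLambda creates, carry no line numbers, and compile() needs them. The add_one function is a hypothetical example, not part of the answer.

```python
import ast
import io

# Hypothetical example function; its return value is replaced by a hand-built node.
tree = ast.parse("def add_one(x):\n    return x\n")
ret = tree.body[0].body[-1]
ret.value = ast.BinOp(
    left=ast.Name(id="x", ctx=ast.Load()),
    op=ast.Add(),
    right=ast.Constant(value=1),
)
# The new nodes have no lineno/col_offset; copy them from the parent nodes.
ast.fix_missing_locations(tree)

out = io.StringIO()            # any file-like object works
out.write(ast.unparse(tree))   # Python 3.9+: ast.unparse returns a str
print(out.getvalue())

namespace = {}
exec(compile(tree, "<generated>", "exec"), namespace)
print(namespace["add_one"](41))  # 42
```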

This is the result obtained for the example function:

def foo(a):

    def _foo_internal_anon_1(b):

        def _foo_internal_anon_2(c):

            def _foo_internal_anon_3(d):

                def _foo_internal_anon_4(e):
                    return ((((a + b) + c) + d) + e)
                return _foo_internal_anon_4
            return _foo_internal_anon_3
        return _foo_internal_anon_2
    return _foo_internal_anon_1

It adds some extra blank lines and redundant parentheses around the addition, but both can be customized by modifying the Unparser class.
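With ast.unparse on Python 3.9+, parentheses are emitted only where operator precedence requires them, so the ((((a + b) + c) + d) + e) noise goes away. If the generated code still has blank lines before nested def statements, a small post-processing step is one way to get the compact layout from the question. The tidy helper below is hypothetical, and it assumes that the code contains no multi-line string literals with blank lines in them.

```python
import ast


def tidy(source):
    """Hypothetical helper: drop blank lines from generated code.

    Assumes no multi-line string literal in the code contains a blank line.
    """
    return "\n".join(line for line in source.splitlines() if line.strip())


flat = ast.unparse(ast.parse("x = a + b + c + d + e"))
grouped = ast.unparse(ast.parse("x = a + (b + c)"))
print(flat)     # x = a + b + c + d + e
print(grouped)  # x = a + (b + c)
print(tidy("def foo(a):\n\n    def bar(b):\n        return a + b\n    return bar"))
```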

answered Sep 29 '22 20:09 by a_guest
