 

Remove type hints in Python source programmatically


I have some source code written for Python 3.5 that I want to make executable under Python 3.4. The only 3.5 feature I use that is not available in 3.4 is type hints, so I'd like to write a script that removes them entirely.

At first glance this seemed easy enough, so I decided to write some regexes, but then I thought about a few edge cases and wasn't sure how to handle a more complex function like this:

def foo(bar: Dict[T, List[T]],
        baz: Callable[[T], int] = lambda x: (x+3)/7,
        **kwargs) -> List[T]:

Basically, I'd have to parse the whole thing and rebuild the argument list without type annotations. How would I even approach this?

asked Mar 11 '17 10:03 by klamann



2 Answers

OK, I got it :D

Use Python's built-in ast module to parse the source code, and then the excellent astunparse library to generate source code from the parsed AST again. Then all that's left is to remove the type annotations:

import ast
import astunparse

source="""
import typing
from typing import Dict, T, Callable
from typing import List

def foo(bar: Dict[T, List[T]],
        baz: Callable[[T], int] = lambda x: (x+3)/7,
        **kwargs) -> List[T]:
    pass
"""

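class TypeHintRemover(ast.NodeTransformer):

    def visit_FunctionDef(self, node):
        # remove the return type definition
        node.returns = None
        # remove the annotations of all arguments: positional (and, on
        # Python 3.8+, positional-only), keyword-only, *args and **kwargs
        args = node.args
        all_args = getattr(args, 'posonlyargs', []) + args.args + args.kwonlyargs
        all_args += [a for a in (args.vararg, args.kwarg) if a is not None]
        for arg in all_args:
            arg.annotation = None
        # keep walking so that nested functions are stripped as well
        self.generic_visit(node)
        return node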

    def visit_Import(self, node):
        node.names = [n for n in node.names if n.name != 'typing']
        return node if node.names else None

    def visit_ImportFrom(self, node):
        return node if node.module != 'typing' else None

# parse the source code into an AST
parsed_source = ast.parse(source)
# remove all type annotations, function return type definitions
# and import statements from 'typing'
transformed = TypeHintRemover().visit(parsed_source)
# convert the AST back to source code
print(astunparse.unparse(transformed))

The TypeHintRemover walks the AST and removes the type hints from function arguments, the return type annotation of each function, and all import statements that refer to the typing module. Note that the ast module does not keep comments or the original formatting, so both are lost in the regenerated code.

The result is:

def foo(bar, baz=(lambda x: ((x + 3) / 7)), **kwargs):
    pass
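To check that nothing was missed (annotations on *args, keyword-only parameters or nested functions are easy to overlook, since they are stored outside node.args.args), you can scan the output for leftover annotations. This is a minimal sketch, not part of the answer above: the helper name find_annotations is hypothetical, and it assumes Python 3.9+ for ast.unparse. It looks in the three places the AST keeps annotations: FunctionDef.returns, each ast.arg node, and AnnAssign statements.

```python
import ast


def find_annotations(source):
    """Return sorted (lineno, description) pairs for every annotation left in source.

    Hypothetical helper; needs Python 3.9+ because it uses ast.unparse.
    """
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # return annotations hang off the function node itself
            if node.returns is not None:
                found.append((node.lineno, node.name + ' -> ' + ast.unparse(node.returns)))
        elif isinstance(node, ast.arg) and node.annotation is not None:
            # covers positional, keyword-only, *args and **kwargs parameters
            found.append((node.lineno, node.arg + ': ' + ast.unparse(node.annotation)))
        elif isinstance(node, ast.AnnAssign):
            # variable annotations (PEP 526) are Python 3.6+ syntax
            found.append((node.lineno, ast.unparse(node.target) + ': ' + ast.unparse(node.annotation)))
    return sorted(found)


# a nested function and an annotated *args are both reported
print(find_annotations("def f(x, *a: int):\n    def g() -> str:\n        return ''\n    return g\n"))
# [(1, 'a: int'), (2, 'g -> str')]
```

An empty list means the stripped code contains no annotations; passing it to compile() afterwards is a cheap way to confirm it is still valid syntax.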
answered Oct 13 '22 11:10 by klamann



You can subclass lib2to3.refactor.RefactoringTool and refactor the code with a fixer that subclasses lib2to3.fixer_base.BaseFix. The fixer's pattern matches a typed argument, a function definition with an annotated return value, or a simple statement that imports typing or imports from it. Its transform method then deletes the annotation nodes from the matched node's children, or replaces the import statement with a blank line:

from lib2to3 import fixer_base, refactor, fixer_util

class FixParameterAnnotations(fixer_base.BaseFix):
    PATTERN = r'''
        name=tname
        |
        func=funcdef< any+ '->' any+ >
        |
        simple_stmt<
            (
                import_name< 'import' 'typing' >
                |
                import_from< 'from' 'typing' 'import' any+ >
            ) '\n'
        >
    '''

    def transform(self, node, results):
        if 'name' in results:
            del node.children[1:] # delete annotation to typed argument
        elif 'func' in results:
            del node.children[-4:-2] # delete annotation to function declaration
        else:
            return fixer_util.BlankLine() # delete statement that imports typing
        return node

class Refactor(refactor.RefactoringTool):
    def __init__(self, fixers):
        self._fixers = [cls(None, None) for cls in fixers]
        super().__init__(None, {'print_function': True})

    def get_fixers(self):
        return self._fixers, []

so that:

source = """
import typing
from typing import Dict, T, Callable
from typing import List

def foo(bar: Dict[T, List[T]],
        baz: Callable[[T], int] = lambda x: (x+3)/7,
        **kwargs) -> List[T]:
    print(line, end="")    # comments and white spaces are preserved
"""
print(Refactor([FixParameterAnnotations]).refactor_string(source, ''))

outputs:

def foo(bar,
        baz = lambda x: (x+3)/7,
        **kwargs):
    print(line, end="")    # comments and white spaces are preserved

Demo: https://repl.it/@blhsing/BurlywoodFeistyTrials

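As a bonus, lib2to3 preserves all comments and whitespace through the transformation. You can find the definition of the Python grammar in lib2to3's Grammar.txt. Note that lib2to3 was removed from the standard library in Python 3.13, so this approach needs an older interpreter.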
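If lib2to3 is not available, a different technique can also keep comments and layout: use ast only to locate the annotations (Python 3.8+ records end positions for every node) and cut those exact spans out of the original text. This is a minimal sketch under stated assumptions, not either answerer's method: the function name strip_annotations_preserving_comments is hypothetical, it assumes each typing import sits on its own line(s), that removing it leaves no empty block, and that annotations are not wrapped in redundant parentheses.

```python
import ast


def strip_annotations_preserving_comments(source):
    """Cut argument/return annotations and typing imports out of source text.

    Sketch only: needs Python 3.8+ (end_lineno/end_col_offset), assumes each
    typing import occupies whole lines of its own, and assumes annotations
    are not wrapped in redundant parentheses.
    """
    tree = ast.parse(source)
    # ast column offsets count UTF-8 bytes, so edit the encoded text;
    # bytes.splitlines only splits on \n, \r\n and \r, like the tokenizer
    lines = source.encode('utf-8').splitlines(keepends=True)
    starts = [0]  # starts[i] is the byte offset where line i + 1 begins
    for line in lines:
        starts.append(starts[-1] + len(line))
    data = b''.join(lines)

    def pos(lineno, col):
        return starts[lineno - 1] + col

    spans = []  # (start, end) byte ranges to delete
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module == 'typing':
            spans.append((starts[node.lineno - 1], starts[node.end_lineno]))
        elif isinstance(node, ast.Import) and all(n.name == 'typing' for n in node.names):
            spans.append((starts[node.lineno - 1], starts[node.end_lineno]))
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            a = node.args
            params = a.posonlyargs + a.args + a.kwonlyargs
            params += [p for p in (a.vararg, a.kwarg) if p is not None]
            for p in params:
                if p.annotation is not None:
                    # from just after the parameter name to the end of the annotation
                    start = pos(p.lineno, p.col_offset) + len(p.arg.encode('utf-8'))
                    end = pos(p.annotation.end_lineno, p.annotation.end_col_offset)
                    spans.append((start, end))
            if node.returns is not None:
                # from the whitespace before '->' to the end of the annotation
                start = data.rfind(b'->', 0, pos(node.returns.lineno, node.returns.col_offset))
                while data[start - 1:start] in (b' ', b'\t'):
                    start -= 1
                end = pos(node.returns.end_lineno, node.returns.end_col_offset)
                spans.append((start, end))
    # delete back to front so that earlier offsets stay valid
    for start, end in sorted(spans, reverse=True):
        data = data[:start] + data[end:]
    return data.decode('utf-8')
```

Applied to the question's function, this should give the same layout as the lib2to3 output (`baz = lambda x: (x+3)/7`, `**kwargs):`) with comments untouched. The design trade-off is that the AST is only used for positions, so anything the position data does not cover (such as redundant parentheses around an annotation) has to be handled separately.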

answered Oct 13 '22 11:10 by blhsing
