Extract Python function source text from the source code string

Tags:

python

Suppose I have valid Python source code, as a string:

code_string = """
# A comment.
def foo(a, b):
  return a + b
class Bar(object):
  def __init__(self):
    self.my_list = [
        'a',
        'b',
    ]
""".strip()

Objective: I would like to obtain the lines containing the source code of the function definitions, preserving whitespace. For the code string above, I would like to get the strings

def foo(a, b):
  return a + b

and

  def __init__(self):
    self.my_list = [
        'a',
        'b',
    ]

Or, equivalently, I'd be happy to get the line numbers of functions in the code string: foo spans lines 2-3, and __init__ spans lines 5-9.

Attempts

I can parse the code string into its AST:

import ast

code_ast = ast.parse(code_string)

And I can find the FunctionDef nodes, e.g.:

function_def_nodes = [node for node in ast.walk(code_ast)
                      if isinstance(node, ast.FunctionDef)]

Each FunctionDef node's lineno attribute tells us the first line for that function. We can estimate the last line of that function with:

last_line = max(node.lineno for node in ast.walk(function_def_node)
                if hasattr(node, 'lineno'))

but this doesn't work perfectly when the function ends with syntactic elements that don't show up as AST nodes, for instance the last ] in __init__.

I doubt there is an approach that only uses the AST, because the AST fundamentally does not have enough information in cases like __init__.
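For readers on Python 3.8 or later, the AST has since gained end positions: every statement node carries an end_lineno attribute, so the last line no longer has to be estimated. The following is a minimal sketch under that version assumption. The helper name function_spans is made up for illustration, and the code string is the one from the top of the question.

```python
import ast

code_string = """
# A comment.
def foo(a, b):
  return a + b
class Bar(object):
  def __init__(self):
    self.my_list = [
        'a',
        'b',
    ]
""".strip()

def function_spans(source):
    """Return (name, first_line, last_line) for every def, in source order.

    Relies on node.end_lineno, which only exists on Python 3.8+.
    """
    tree = ast.parse(source)
    spans = [
        (node.name, node.lineno, node.end_lineno)
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    return sorted(spans, key=lambda span: span[1])

spans = function_spans(code_string)
print(spans)  # [('foo', 2, 3), ('__init__', 5, 9)]

# Slice the original text by line number to keep the whitespace untouched.
source_lines = code_string.split("\n")
for name, first, last in spans:
    print("\n".join(source_lines[first - 1:last]))
```

On older interpreters end_lineno is not available, so this sketch does not apply there.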

I cannot use the inspect module because that only works on "live objects" and I only have the Python code as a string. I cannot eval the code because that's a huge security headache.

In theory I could write a parser for Python but that really seems like overkill.

A heuristic suggested in the comments is to use the leading whitespace of lines. However, that can break for strange but valid functions with weird indentation like:

def baz():
  return [
1,
  ]

class Baz(object):
  def hello(self, x):
    return self.hello(
x - 1)

def my_type_annotated_function(
  my_long_argument_name: SomeLongArgumentTypeName
) -> SomeLongReturnTypeName:
  # This function's indentation isn't unusual at all.
  pass

k_ssb asked Jan 25 '19 23:01

2 Answers

A much more robust solution would be to use the tokenize module. The following code can handle weird indentations, comments, multi-line tokens, single-line function blocks and empty lines within function blocks:

import tokenize
from io import BytesIO
from collections import deque
code_string = """
# A comment.
def foo(a, b):
  return a + b

class Bar(object):
  def __init__(self):

    self.my_list = [
        'a',
        'b',
    ]

  def test(self): pass
  def abc(self):
    '''multi-
    line token'''

def baz():
  return [
1,
  ]

class Baz(object):
  def hello(self, x):
    a = \
1
    return self.hello(
x - 1)

def my_type_annotated_function(
  my_long_argument_name: SomeLongArgumentTypeName
) -> SomeLongReturnTypeName:
  pass
  # unmatched parenthesis: (
""".strip()
file = BytesIO(code_string.encode())
tokens = deque(tokenize.tokenize(file.readline))
lines = []
while tokens:
    token = tokens.popleft()
    if token.type == tokenize.NAME and token.string == 'def':
        start_line, _ = token.start
        last_token = token
        while tokens:
            token = tokens.popleft()
            if token.type == tokenize.NEWLINE:
                break
            last_token = token
        if last_token.type == tokenize.OP and last_token.string == ':':
            indents = 0
            while tokens:
                token = tokens.popleft()
                if token.type == tokenize.NL:
                    continue
                if token.type == tokenize.INDENT:
                    indents += 1
                elif token.type == tokenize.DEDENT:
                    indents -= 1
                    if not indents:
                        break
                else:
                    last_token = token
        lines.append((start_line, last_token.end[0]))
print(lines)

This outputs:

[(2, 3), (6, 11), (13, 13), (14, 16), (18, 21), (24, 27), (29, 33)]
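Since the question asks for the function text itself, the (start, end) pairs can be turned back into strings by slicing the source on line boundaries. This is only a sketch: extract_spans and the short sample string are illustrative names, not part of the code above.

```python
def extract_spans(source, spans):
    """Return the source text for each 1-based, inclusive (start, end) span."""
    # Split on "\n" only, so the numbering matches what tokenize counts.
    lines = source.split("\n")
    return ["\n".join(lines[start - 1:end]) for start, end in spans]

sample = (
    "# A comment.\n"
    "def foo(a, b):\n"
    "  return a + b\n"
    "class Bar(object):\n"
    "  def __init__(self):\n"
    "    pass"
)

for text in extract_spans(sample, [(2, 3), (5, 6)]):
    print(text)
    print("---")
```

The leading whitespace of each line is preserved, so a method keeps the indentation it has inside its class.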

Note however that the continuation line:

a = \
1

reaches tokenize as one line even though it is written on two lines in the literal above. If you print the tokens:

TokenInfo(type=53 (OP), string=':', start=(24, 20), end=(24, 21), line='  def hello(self, x):\n')
TokenInfo(type=4 (NEWLINE), string='\n', start=(24, 21), end=(24, 22), line='  def hello(self, x):\n')
TokenInfo(type=5 (INDENT), string='    ', start=(25, 0), end=(25, 4), line='    a = 1\n')
TokenInfo(type=1 (NAME), string='a', start=(25, 4), end=(25, 5), line='    a = 1\n')
TokenInfo(type=53 (OP), string='=', start=(25, 6), end=(25, 7), line='    a = 1\n')
TokenInfo(type=2 (NUMBER), string='1', start=(25, 8), end=(25, 9), line='    a = 1\n')
TokenInfo(type=4 (NEWLINE), string='\n', start=(25, 9), end=(25, 10), line='    a = 1\n')
TokenInfo(type=1 (NAME), string='return', start=(26, 4), end=(26, 10), line='    return self.hello(\n')

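you can see that tokenize receives the statement as the single line '    a = 1\n', with only one line number, 25. This is not a limitation of the tokenize module: code_string is an ordinary (non-raw) triple-quoted literal, so Python drops the backslash and the newline that follows it while building the string, before tokenize ever sees them. The line numbers in the output above are therefore relative to the string's actual contents. Writing the literal as a raw string (r"""...""") or reading the source from a file keeps the continuation as two physical lines.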
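To check where the collapsed line number comes from, it helps to compare a plain triple-quoted literal with a raw one. The sketch below uses a made-up helper, number_token_line, and two tiny sample strings. It only demonstrates how Python string literals treat a backslash followed by a newline.

```python
import io
import tokenize

# Plain literal: the backslash-newline pair is removed by the string literal.
plain_source = """a = \
1
"""

# Raw literal: the backslash and the newline are both kept in the string.
raw_source = r"""a = \
1
"""

def number_token_line(source):
    """Return the line number tokenize reports for the first NUMBER token."""
    readline = io.BytesIO(source.encode()).readline
    for token in tokenize.tokenize(readline):
        if token.type == tokenize.NUMBER:
            return token.start[0]
    return None

print(repr(plain_source))               # 'a = 1\n'
print(repr(raw_source))                 # 'a = \\\n1\n'
print(number_token_line(plain_source))  # 1
print(number_token_line(raw_source))    # 2
```

With the raw literal, tokenize reports the number on the second physical line, as expected for a backslash continuation.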

blhsing answered Oct 15 '22 05:10


Rather than reinventing a parser, I would use Python itself.

Basically I would use the compile() built-in function, which can check whether a string is valid Python code by compiling it. I pass it a string made of selected lines, starting at each def and extending to the farthest line for which compilation still succeeds.

code_string = """
#A comment
def foo(a, b):
  return a + b

def bir(a, b):
  c = a + b
  return c

class Bar(object):
  def __init__(self):
    self.my_list = [
        'a',
        'b',
    ]

def baz():
  return [
1,
  ]

""".strip()

lines = code_string.split('\n')

#looking for lines with 'def' keywords
defidxs = [e[0] for e in enumerate(lines) if 'def' in e[1]]

#getting the indentation of each 'def'
indents = {}
for i in defidxs:
    ll = lines[i].split('def')
    indents[i] = len(ll[0])

#extracting the strings
end = len(lines)-1
while end > 0:
    if end < defidxs[-1]:
        defidxs.pop()
    try:
        start = defidxs[-1]
    except IndexError: #break if there are no more 'def'
        break

    #empty lines between functions will cause an error, let's remove them
    if len(lines[end].strip()) == 0:
        end = end -1
        continue

    try:
        #strip the def's indentation, otherwise compile raises an IndentationError
        fixlines = [ll[indents[start]:] for ll in lines[start:end+1]]
        body = '\n'.join(fixlines)
        compile(body, '<string>', 'exec') #if it fails, throws an exception
        print(body)
        end = start #on success, skip the remaining lines of this function
    except:
        pass

    end = end -1

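It is a bit nasty because of the bare except clause, which is usually not recommended. I did not know which exceptions compile may raise when I wrote it. Its documentation names SyntaxError for invalid source (IndentationError is a subclass) and ValueError for source containing null bytes, so catching those two is probably enough.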
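As a small illustration of the check the loop relies on, the sketch below wraps compile() in a helper that catches only the two documented exceptions. The name compiles and the two sample snippets are made up for this example. Note that compile() only builds a code object and does not execute the source.

```python
def compiles(source):
    """Return True if `source` compiles as a module, False otherwise."""
    try:
        compile(source, "<string>", "exec")
    except (SyntaxError, ValueError):
        # SyntaxError covers IndentationError; ValueError covers null bytes.
        return False
    return True

complete = "def baz():\n  return [\n1,\n  ]"
truncated = "def baz():\n  return [\n1,"

print(compiles(complete))   # True
print(compiles(truncated))  # False
```

The truncated snippet fails because the bracket is never closed, which is what lets the loop find where a function ends.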

This prints:

def baz():
  return [
1,
  ]
def __init__(self):
  self.my_list = [
      'a',
      'b',
  ]
def bir(a, b):
  c = a + b
  return c
def foo(a, b):
  return a + b

Note that the functions are printed in the reverse of the order in which they appear in code_string.

This should handle even the weirdly indented code, but I think it will fail if you have nested functions.

Valentino answered Oct 15 '22 04:10