diff options
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/origin_info.py')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/origin_info.py | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index b3c6a43d37..614e346634 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -17,10 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple +import collections +import gast -class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))): +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.python.util import tf_inspect + + +class CodeLocation( + collections.namedtuple('CodeLocation', ('file_path', 'line_number'))): """Location of a line of code. Attributes: @@ -31,8 +37,9 @@ class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))): class OriginInfo( - namedtuple('OriginInfo', ('file_path', 'function_name', 'line_number', - 'column_offset', 'source_code_line'))): + collections.namedtuple('OriginInfo', + ('file_path', 'function_name', 'line_number', + 'column_offset', 'source_code_line'))): """Container for information about the source code before conversion. Instances of this class contain information about the source code that @@ -50,3 +57,44 @@ class OriginInfo( """ return (self.file_path, self.line_number, self.function_name, self.source_code_line) + + +# TODO(znado): Consider refactoring this into a Visitor. +def resolve(node, source, function=None): + """Adds an origin information to all nodes inside the body of function. + + Args: + node: The AST node for the function whose body nodes will be annotated. + source: Text, the source code string for the function whose body nodes will + be annotated. + function: Callable, the function that will have all nodes inside of it + annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If + it is None then only the line numbers and column offset will be set in the + annotation, with the rest of the information being None. + + Returns: + A tuple of the AST node for function and a String containing its source + code. + """ + if function: + _, function_lineno = tf_inspect.getsourcelines(function) + function_filepath = tf_inspect.getsourcefile(function) + else: + function_lineno = None + function_filepath = None + source_lines = source.split('\n') + for n in gast.walk(node): + if hasattr(n, 'lineno'): + # n.lineno is relative to the start of the enclosing function, so need to + # offset it by the line of the function. + source_code_line = source_lines[n.lineno - 1] + if function: + source_lineno = n.lineno + function_lineno - 1 + function_name = function.__name__ + else: + source_lineno = n.lineno + function_name = None + anno.setanno( + n, anno.Basic.ORIGIN, + OriginInfo(function_filepath, function_name, source_lineno, + n.col_offset, source_code_line)) |