aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/origin_info.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/origin_info.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py56
1 files changed, 52 insertions, 4 deletions
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
index b3c6a43d37..614e346634 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -17,10 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from collections import namedtuple
+import collections
+import gast
-class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))):
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.util import tf_inspect
+
+
+class CodeLocation(
+ collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
"""Location of a line of code.
Attributes:
@@ -31,8 +37,9 @@ class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))):
class OriginInfo(
- namedtuple('OriginInfo', ('file_path', 'function_name', 'line_number',
- 'column_offset', 'source_code_line'))):
+ collections.namedtuple('OriginInfo',
+ ('file_path', 'function_name', 'line_number',
+ 'column_offset', 'source_code_line'))):
"""Container for information about the source code before conversion.
Instances of this class contain information about the source code that
@@ -50,3 +57,44 @@ class OriginInfo(
"""
return (self.file_path, self.line_number, self.function_name,
self.source_code_line)
+
+
+# TODO(znado): Consider refactoring this into a Visitor.
+def resolve(node, source, function=None):
+ """Adds an origin information to all nodes inside the body of function.
+
+ Args:
+ node: The AST node for the function whose body nodes will be annotated.
+ source: Text, the source code string for the function whose body nodes will
+ be annotated.
+ function: Callable, the function that will have all nodes inside of it
+ annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If
+ it is None then only the line numbers and column offset will be set in the
+ annotation, with the rest of the information being None.
+
+ Returns:
+ A tuple of the AST node for function and a String containing its source
+ code.
+ """
+ if function:
+ _, function_lineno = tf_inspect.getsourcelines(function)
+ function_filepath = tf_inspect.getsourcefile(function)
+ else:
+ function_lineno = None
+ function_filepath = None
+ source_lines = source.split('\n')
+ for n in gast.walk(node):
+ if hasattr(n, 'lineno'):
+ # n.lineno is relative to the start of the enclosing function, so need to
+ # offset it by the line of the function.
+ source_code_line = source_lines[n.lineno - 1]
+ if function:
+ source_lineno = n.lineno + function_lineno - 1
+ function_name = function.__name__
+ else:
+ source_lineno = n.lineno
+ function_name = None
+ anno.setanno(
+ n, anno.Basic.ORIGIN,
+ OriginInfo(function_filepath, function_name, source_lineno,
+ n.col_offset, source_code_line))