aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/origin_info.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/origin_info.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py100
1 files changed, 100 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
new file mode 100644
index 0000000000..614e346634
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -0,0 +1,100 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Container for origin source code information before AutoGraph compilation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.util import tf_inspect
+
+
+class CodeLocation(
+ collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
+ """Location of a line of code.
+
+ Attributes:
+ file_path: text, the full path to the file containing the code.
+ line_number: Int, the 1-based line number of the code in its file.
+ """
+ pass
+
+
+class OriginInfo(
+ collections.namedtuple('OriginInfo',
+ ('file_path', 'function_name', 'line_number',
+ 'column_offset', 'source_code_line'))):
+ """Container for information about the source code before conversion.
+
+ Instances of this class contain information about the source code that
+ transformed code originated from. Examples include:
+ * line number
+ * file name
+ * original user code
+ """
+
+ def as_frame(self):
+ """Makes a traceback frame tuple.
+
+ Returns:
+ A tuple of (file_path, line_number, function_name, source_code_line).
+ """
+ return (self.file_path, self.line_number, self.function_name,
+ self.source_code_line)
+
+
+# TODO(znado): Consider refactoring this into a Visitor.
+def resolve(node, source, function=None):
+ """Adds an origin information to all nodes inside the body of function.
+
+ Args:
+ node: The AST node for the function whose body nodes will be annotated.
+ source: Text, the source code string for the function whose body nodes will
+ be annotated.
+ function: Callable, the function that will have all nodes inside of it
+ annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If
+ it is None then only the line numbers and column offset will be set in the
+ annotation, with the rest of the information being None.
+
+ Returns:
+ A tuple of the AST node for function and a String containing its source
+ code.
+ """
+ if function:
+ _, function_lineno = tf_inspect.getsourcelines(function)
+ function_filepath = tf_inspect.getsourcefile(function)
+ else:
+ function_lineno = None
+ function_filepath = None
+ source_lines = source.split('\n')
+ for n in gast.walk(node):
+ if hasattr(n, 'lineno'):
+ # n.lineno is relative to the start of the enclosing function, so need to
+ # offset it by the line of the function.
+ source_code_line = source_lines[n.lineno - 1]
+ if function:
+ source_lineno = n.lineno + function_lineno - 1
+ function_name = function.__name__
+ else:
+ source_lineno = n.lineno
+ function_name = None
+ anno.setanno(
+ n, anno.Basic.ORIGIN,
+ OriginInfo(function_filepath, function_name, source_lineno,
+ n.col_offset, source_code_line))