aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/compiler.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/compiler.py111
1 files changed, 102 insertions, 9 deletions
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py
index 24c4517afa..c90a5e89c2 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/contrib/autograph/pyct/compiler.py
@@ -30,9 +30,58 @@ import tempfile
import astor
import gast
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.contrib.autograph.pyct import parser
+
+
+def _build_source_map(node, code):
+ """Return the Python objects represented by given AST.
+
+ Compiling the AST code this way ensures that the source code is readable by
+ e.g. `pdb` or `inspect`.
+
+ Args:
+ node: An AST node of the original generated code, before the source code is
+ generated.
+ code: The string representation of the source code for the newly generated
+ code.
+
+ Returns:
+ Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph
+ generated code.
+ """
+ # After we have the final generated code we reparse it to get the final line
+ # numbers. Then we walk through the generated and original ASTs in parallel
+ # to build the mapping between the user and generated code.
+ new_node = parser.parse_str(code)
+ origin_info.resolve(new_node, code)
+ source_mapping = {}
+ for before, after in ast_util.parallel_walk(node, new_node):
+ # Need both checks because if origin information is ever copied over to new
+ # nodes then we need to rely on the fact that only the original user code
+ # has the origin annotation.
+ if (anno.hasanno(before, anno.Basic.ORIGIN) and
+ anno.hasanno(after, anno.Basic.ORIGIN)):
+ source_info = anno.getanno(before, anno.Basic.ORIGIN)
+ new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
+ source_mapping[new_line_number] = source_info
+ return source_mapping
+
def ast_to_source(node, indentation=' '):
- """Return the source code of given AST."""
+ """Return the source code of given AST.
+
+ Args:
+ node: The code to compile, as an AST object.
+ indentation: The string to use for indentation.
+
+ Returns:
+ code: The source code generated from the AST object
+ source_mapping: A mapping between the user and AutoGraph generated code.
+ """
+ original_node = node
if isinstance(node, gast.AST):
node = gast.gast_to_ast(node)
generator = astor.codegen.SourceGenerator(indentation, False,
@@ -42,11 +91,16 @@ def ast_to_source(node, indentation=' '):
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
code = map(str, generator.result)
- return astor.source_repr.pretty_source(code).lstrip()
+ code = astor.source_repr.pretty_source(code).lstrip()
+ source_mapping = _build_source_map(original_node, code)
+
+ return code, source_mapping
-def ast_to_object(
- node, indentation=' ', source_prefix=None, delete_on_exit=True):
+def ast_to_object(node,
+ indentation=' ',
+ source_prefix=None,
+ delete_on_exit=True):
"""Return the Python objects represented by given AST.
Compiling the AST code this way ensures that the source code is readable by
@@ -56,15 +110,31 @@ def ast_to_object(
node: The code to compile, as an AST object.
indentation: The string to use for indentation.
source_prefix: Optional string to print as-is into the source file.
- delete_on_exit: Whether to delete the temporary file used for compilation
- on exit.
+ delete_on_exit: Whether to delete the temporary file used for compilation on
+ exit.
Returns:
- A module object containing the compiled source code.
+ compiled_node: A module object containing the compiled source code.
+ source: The source code of the compiled object
+ Raises:
+ ValueError: If ag_source_map__ is already in the namespace of the compiled
+ node.
"""
- source = ast_to_source(node, indentation)
+ # code_source_mapping does not yet include the offsets from import statements.
+ source, code_source_mapping = ast_to_source(node, indentation=indentation)
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
+ # TODO(znado): move into an _offset_source_map() helper function.
+ # Need to offset the generated line numbers by the number of import lines.
+ if source_prefix:
+ num_import_lines = source_prefix.count('\n') + 1
+ else:
+ num_import_lines = 0
+ source_mapping = {}
+ for line_number, original_position in code_source_mapping.items():
+ source_map_key = origin_info.CodeLocation(
+ file_path=f.name, line_number=line_number + num_import_lines)
+ source_mapping[source_map_key] = original_position
module_name = os.path.basename(f.name[:-3])
if source_prefix:
f.write(source_prefix)
@@ -72,4 +142,27 @@ def ast_to_object(
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
- return imp.load_source(module_name, f.name), source
+ compiled_node = imp.load_source(module_name, f.name)
+
+ # TODO(znado): Clean this up so we don't need to attach it to the namespace.
+ # TODO(znado): This does not work for classes because their methods share a
+ # namespace.
+ # This attaches the source map which is needed for error handling. Note that
+ # api.to_graph copies this source map into an attribute of the function.
+ #
+ # We need this so the ag_source_map__ variable is available to the call to
+ # rewrite_graph_construction_error in the except block inside each function
+ # that handles graph construction errors.
+ #
+ # We cannot get the rewritten function name until it is too late so templating
+ # is hard, and this cleanly fixes the
+ # issues encountered with nested functions because this is attached to the
+ # outermost one.
+ source_map_name = 'ag_source_map__'
+ if source_map_name in compiled_node.__dict__:
+ raise ValueError('cannot convert %s because is has namespace attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled_node, source_map_name))
+ compiled_node.__dict__[source_map_name] = source_mapping
+
+ return compiled_node, source