diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/compiler.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/compiler.py | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py new file mode 100644 index 0000000000..9e1b6bdbe8 --- /dev/null +++ b/tensorflow/python/autograph/pyct/compiler.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converting AST to code. + +Adapted from Tangent. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# TODO(mdan): Use six for compatibility here. +import atexit +import imp +import os +import tempfile + +import astor +import gast + +from tensorflow.python.autograph.pyct import origin_info + + +def ast_to_source(node, indentation=' '): + """Return the source code of given AST. + + Args: + node: The code to compile, as an AST object. + indentation: The string to use for indentation. + + Returns: + code: The source code generated from the AST object + source_mapping: A mapping between the user and AutoGraph generated code. + """ + if not isinstance(node, (list, tuple)): + node = (node,) + generator = astor.codegen.SourceGenerator(indentation, False, + astor.string_repr.pretty_string) + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + + # In some versions of Python, literals may appear as actual values. This + # ensures everything is string. + code = map(str, generator.result) + code = astor.source_repr.pretty_source(code).lstrip() + + return code + + +def ast_to_object(nodes, + indentation=' ', + include_source_map=False, + source_prefix=None, + delete_on_exit=True): + """Return the Python objects represented by given AST. + + Compiling the AST code this way ensures that the source code is readable by + e.g. `pdb` or `inspect`. + + Args: + nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST + object. + indentation: Text, the string to use for indentation. + include_source_map: bool, whether to attach a source map to the compiled + object. Also see origin_info.py. + source_prefix: Optional[Text], string to print as-is into the source file. + delete_on_exit: bool, whether to delete the temporary file used for + compilation on exit. + + Returns: + compiled_nodes: A module object containing the compiled source code. + source: The source code of the compiled object + Raises: + ValueError: If ag_source_map__ is already in the namespace of the compiled + nodes. + """ + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + source = ast_to_source(nodes, indentation=indentation) + + if source_prefix: + source = source_prefix + '\n' + source + + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + module_name = os.path.basename(f.name[:-3]) + f.write(source) + + if isinstance(nodes, (list, tuple)): + indices = range(-len(nodes), 0) + else: + indices = (-1,) + + if include_source_map: + source_map = origin_info.source_map(nodes, source, f.name, indices) + + # TODO(mdan): Try flush() and delete=False instead. + if delete_on_exit: + atexit.register(lambda: os.remove(f.name)) + compiled_nodes = imp.load_source(module_name, f.name) + + # TODO(znado): Clean this up so we don't need to attach it to the namespace. + # TODO(znado): This does not work for classes because their methods share a + # namespace. + # This attaches the source map which is needed for error handling. Note that + # api.to_graph copies this source map into an attribute of the function. + # + # We need this so the ag_source_map__ variable is available to the call to + # rewrite_graph_construction_error in the except block inside each function + # that handles graph construction errors. + # + # We cannot get the rewritten function name until it is too late so templating + # is hard, and this cleanly fixes the + # issues encountered with nested functions because this is attached to the + # outermost one. + if include_source_map: + # TODO(mdan): This name should be decided by the caller. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_nodes.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_nodes, source_map_name)) + compiled_nodes.__dict__[source_map_name] = source_map + + return compiled_nodes, source |