aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/impl/conversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/impl/conversion.py')
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py76
1 files changed, 41 insertions, 35 deletions
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 776d19f672..7bd0ba3f2d 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -28,26 +28,27 @@ from tensorflow.contrib.autograph.converters import asserts
from tensorflow.contrib.autograph.converters import break_statements
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.converters import call_trees
+from tensorflow.contrib.autograph.converters import conditional_expressions
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import ifexp
+from tensorflow.contrib.autograph.converters import directives
+from tensorflow.contrib.autograph.converters import error_handlers
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.converters import logical_expressions
from tensorflow.contrib.autograph.converters import name_scopes
+from tensorflow.contrib.autograph.converters import return_statements
from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.converters import single_return
from tensorflow.contrib.autograph.converters import slices
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.core import errors
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import inspect_utils
+from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
from tensorflow.python.util import tf_inspect
@@ -157,7 +158,8 @@ def class_to_graph(c, program_ctx):
program_ctx=program_ctx,
arg_values={},
arg_types={'self': (c.__name__, c)},
- owner_type=c)
+ owner_type=c,
+ rewrite_errors=False)
if class_namespace is None:
class_namespace = namespace
else:
@@ -231,6 +233,8 @@ def _add_self_references(namespace, autograph_module):
ag_internal = imp.new_module('autograph')
ag_internal.converted_call = autograph_module.converted_call
ag_internal.utils = utils
+ ag_internal.rewrite_graph_construction_error = (
+ errors.rewrite_graph_construction_error)
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
@@ -239,11 +243,17 @@ def _add_self_references(namespace, autograph_module):
_add_reserved_symbol(namespace, 'ag__', ag_internal)
-def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
+def function_to_graph(f,
+ program_ctx,
+ arg_values,
+ arg_types,
+ owner_type=None,
+ rewrite_errors=True):
"""Specialization of `entity_to_graph` for callable functions."""
+
node, source = parser.parse_entity(f)
node = node.body[0]
-
+ origin_info.resolve(node, source, f)
namespace = inspect_utils.getnamespace(f)
_add_self_references(namespace, program_ctx.autograph_module)
namer = program_ctx.new_namer(namespace)
@@ -256,7 +266,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
arg_types=arg_types,
owner_type=owner_type)
context = converter.EntityContext(namer, entity_info, program_ctx)
- node = node_to_graph(node, context)
+ node = node_to_graph(node, context, rewrite_errors=rewrite_errors)
# TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
@@ -272,22 +282,13 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
return node, new_name, namespace
-def _apply_transformer(node, context, converter_module):
- # TODO(mdan): Clear static analysis here.
- node = qual_names.resolve(node)
- node = activity.resolve(node, context.info, None)
- node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
- node = type_info.resolve(node, context.info)
- node = converter_module.transform(node, context)
- return node
-
-
-def node_to_graph(node, context):
+def node_to_graph(node, context, rewrite_errors=True):
"""Convert Python code to equivalent TF graph mode code.
Args:
node: AST, the code to convert.
context: converter.EntityContext
+ rewrite_errors: Boolean, whether or not to rewrite the error traceback.
Returns:
A tuple (node, deps):
@@ -295,28 +296,33 @@ def node_to_graph(node, context):
* deps: A set of strings, the fully qualified names of entity
dependencies that this node has.
"""
- # TODO(mdan): Verify arguments for correctness.
+ # TODO(mdan): Insert list_comprehensions somewhere.
- node = _apply_transformer(node, context, ifexp)
+ node = converter.standard_analysis(node, context, is_initial=True)
# Past this point, line numbers are no longer accurate so we ignore the
# source.
# TODO(mdan): Is it feasible to reconstruct intermediate source code?
context.info.source_code = None
- node = _apply_transformer(node, context, decorators)
- node = _apply_transformer(node, context, break_statements)
- node = _apply_transformer(node, context, asserts)
+
+ node = converter.apply_(node, context, decorators)
+ node = converter.apply_(node, context, directives)
+ node = converter.apply_(node, context, break_statements)
+ node = converter.apply_(node, context, asserts)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
- node = _apply_transformer(node, context, continue_statements)
+ node = converter.apply_(node, context, continue_statements)
context.info.namespace['len'] = len
- node = _apply_transformer(node, context, single_return)
- node = _apply_transformer(node, context, lists)
- node = _apply_transformer(node, context, slices)
- node = _apply_transformer(node, context, builtin_functions)
- node = _apply_transformer(node, context, call_trees)
- node = _apply_transformer(node, context, control_flow)
- node = _apply_transformer(node, context, logical_expressions)
- node = _apply_transformer(node, context, side_effect_guards)
- node = _apply_transformer(node, context, name_scopes)
+ node = converter.apply_(node, context, return_statements)
+ node = converter.apply_(node, context, lists)
+ node = converter.apply_(node, context, slices)
+ node = converter.apply_(node, context, builtin_functions)
+ node = converter.apply_(node, context, call_trees)
+ node = converter.apply_(node, context, control_flow)
+ node = converter.apply_(node, context, conditional_expressions)
+ node = converter.apply_(node, context, logical_expressions)
+ node = converter.apply_(node, context, side_effect_guards)
+ node = converter.apply_(node, context, name_scopes)
+ if rewrite_errors:
+ node = converter.apply_(node, context, error_handlers)
return node