diff options
Diffstat (limited to 'tensorflow/contrib/autograph/impl/conversion.py')
-rw-r--r-- | tensorflow/contrib/autograph/impl/conversion.py | 76 |
1 files changed, 41 insertions, 35 deletions
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 776d19f672..7bd0ba3f2d 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -28,26 +28,27 @@ from tensorflow.contrib.autograph.converters import asserts from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import conditional_expressions from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import decorators -from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import directives +from tensorflow.contrib.autograph.converters import error_handlers from tensorflow.contrib.autograph.converters import lists from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import return_statements from tensorflow.contrib.autograph.converters import side_effect_guards -from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect @@ -157,7 +158,8 @@ def class_to_graph(c, program_ctx): program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, - owner_type=c) + owner_type=c, + rewrite_errors=False) if class_namespace is None: class_namespace = namespace else: @@ -231,6 +233,8 @@ def _add_self_references(namespace, autograph_module): ag_internal = imp.new_module('autograph') ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils + ag_internal.rewrite_graph_construction_error = ( + errors.rewrite_graph_construction_error) # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be # accessible as ag__.<operator> @@ -239,11 +243,17 @@ def _add_self_references(namespace, autograph_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): +def function_to_graph(f, + program_ctx, + arg_values, + arg_types, + owner_type=None, + rewrite_errors=True): """Specialization of `entity_to_graph` for callable functions.""" + node, source = parser.parse_entity(f) node = node.body[0] - + origin_info.resolve(node, source, f) namespace = inspect_utils.getnamespace(f) _add_self_references(namespace, program_ctx.autograph_module) namer = program_ctx.new_namer(namespace) @@ -256,7 +266,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): arg_types=arg_types, owner_type=owner_type) context = converter.EntityContext(namer, entity_info, program_ctx) - node = node_to_graph(node, context) + node = node_to_graph(node, context, rewrite_errors=rewrite_errors) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -272,22 +282,13 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): return node, new_name, namespace -def _apply_transformer(node, context, converter_module): - # TODO(mdan): Clear static analysis here. - node = qual_names.resolve(node) - node = activity.resolve(node, context.info, None) - node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) - node = type_info.resolve(node, context.info) - node = converter_module.transform(node, context) - return node - - -def node_to_graph(node, context): +def node_to_graph(node, context, rewrite_errors=True): """Convert Python code to equivalent TF graph mode code. Args: node: AST, the code to convert. context: converter.EntityContext + rewrite_errors: Boolean, whether or not to rewrite the error traceback. Returns: A tuple (node, deps): @@ -295,28 +296,33 @@ def node_to_graph(node, context): * deps: A set of strings, the fully qualified names of entity dependencies that this node has. """ - # TODO(mdan): Verify arguments for correctness. + # TODO(mdan): Insert list_comprehensions somewhere. - node = _apply_transformer(node, context, ifexp) + node = converter.standard_analysis(node, context, is_initial=True) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? context.info.source_code = None - node = _apply_transformer(node, context, decorators) - node = _apply_transformer(node, context, break_statements) - node = _apply_transformer(node, context, asserts) + + node = converter.apply_(node, context, decorators) + node = converter.apply_(node, context, directives) + node = converter.apply_(node, context, break_statements) + node = converter.apply_(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = _apply_transformer(node, context, continue_statements) + node = converter.apply_(node, context, continue_statements) context.info.namespace['len'] = len - node = _apply_transformer(node, context, single_return) - node = _apply_transformer(node, context, lists) - node = _apply_transformer(node, context, slices) - node = _apply_transformer(node, context, builtin_functions) - node = _apply_transformer(node, context, call_trees) - node = _apply_transformer(node, context, control_flow) - node = _apply_transformer(node, context, logical_expressions) - node = _apply_transformer(node, context, side_effect_guards) - node = _apply_transformer(node, context, name_scopes) + node = converter.apply_(node, context, return_statements) + node = converter.apply_(node, context, lists) + node = converter.apply_(node, context, slices) + node = converter.apply_(node, context, builtin_functions) + node = converter.apply_(node, context, call_trees) + node = converter.apply_(node, context, control_flow) + node = converter.apply_(node, context, conditional_expressions) + node = converter.apply_(node, context, logical_expressions) + node = converter.apply_(node, context, side_effect_guards) + node = converter.apply_(node, context, name_scopes) + if rewrite_errors: + node = converter.apply_(node, context, error_handlers) return node |