diff options
Diffstat (limited to 'tensorflow/python/autograph/impl/conversion.py')
-rw-r--r-- | tensorflow/python/autograph/impl/conversion.py | 351 |
1 files changed, 351 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py new file mode 100644 index 0000000000..928ff9e7ea --- /dev/null +++ b/tensorflow/python/autograph/impl/conversion.py @@ -0,0 +1,351 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Core conversion logic, serves as main point of access.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import imp + +import gast + +from tensorflow.python.autograph import operators +from tensorflow.python.autograph import utils +from tensorflow.python.autograph.converters import asserts +from tensorflow.python.autograph.converters import break_statements +from tensorflow.python.autograph.converters import builtin_functions +from tensorflow.python.autograph.converters import call_trees +from tensorflow.python.autograph.converters import conditional_expressions +from tensorflow.python.autograph.converters import continue_statements +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.converters import decorators +from tensorflow.python.autograph.converters import directives +from tensorflow.python.autograph.converters import error_handlers +from tensorflow.python.autograph.converters import lists +from tensorflow.python.autograph.converters import logical_expressions +from tensorflow.python.autograph.converters import name_scopes +from tensorflow.python.autograph.converters import return_statements +from tensorflow.python.autograph.converters import side_effect_guards +from tensorflow.python.autograph.converters import slices +from tensorflow.python.autograph.core import config +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.core import errors +from tensorflow.python.autograph.pyct import ast_util +from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.pyct import origin_info +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import templates +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.util import tf_inspect + + +# TODO(mdan): Might we not need any renaming at all? + + +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + if hasattr(o, 'autograph_info__'): + return True + return False + + +def entity_to_graph(o, program_ctx, arg_values, arg_types): + """Compile a Python entity into equivalent TensorFlow. + + The function will also recursively compile all the entities that `o` + references, updating `dependency_cache`. + + This function is reentrant, and relies on dependency_cache to avoid + generating duplicate code. + + Args: + o: A Python entity. + program_ctx: A ProgramContext object. + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function + parameters. + + Returns: + A tuple (ast, new_name, namespace): + * ast: An AST representing an entity with interface equivalent to `o`, + but which when executed it creates TF a graph. + * new_name: The symbol name under which the new entity can be found. + * namespace: A dict mapping all symbols visible to the converted entity, + keyed by their symbol name. + + Raises: + ValueError: if the entity type is not supported. + """ + if tf_inspect.isclass(o): + node, name, ns = class_to_graph(o, program_ctx) + elif tf_inspect.isfunction(o): + # TODO(mdan): This is not a reliable mechanism. + # The most reliable way is to check the source code, the AST will contain + # a Lambda node instead of a FunctionDef + if o.__name__ == '<lambda>': + raise NotImplementedError( + 'lambda functions are not yet supported; declare the function' + ' using def instead: %s' % o) + else: + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + elif tf_inspect.ismethod(o): + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') + else: + raise ValueError( + 'Entity "%s" has unsupported type "%s". Only functions and classes are ' + 'supported for now.' % (o, type(o))) + + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + + program_ctx.add_to_cache(o, node) + + if program_ctx.recursive: + while True: + candidate = None + for obj in program_ctx.name_map.keys(): + if obj not in program_ctx.dependency_cache: + candidate = obj + break + if candidate is None: + break + if (hasattr(candidate, 'im_class') and + getattr(candidate, 'im_class') not in program_ctx.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + entity_to_graph(candidate, program_ctx, {}, {}) + + return node, name, ns + + +def class_to_graph(c, program_ctx): + """Specialization of `entity_to_graph` for classes.""" + converted_members = {} + method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) + members = tf_inspect.getmembers(c, predicate=method_filter) + if not members: + raise ValueError('Cannot convert %s: it has no member methods.' % c) + + class_namespace = {} + for _, m in members: + # Only convert the members that are directly defined by the class. + if inspect_utils.getdefiningclass(m, c) is not c: + continue + node, _, namespace = function_to_graph( + m, + program_ctx=program_ctx, + arg_values={}, + arg_types={'self': (c.__name__, c)}, + owner_type=c, + rewrite_errors=False) + if class_namespace is None: + class_namespace = namespace + else: + class_namespace.update(namespace) + converted_members[m] = node[0] + namer = program_ctx.new_namer(class_namespace) + class_name = namer.compiled_class_name(c.__name__, c) + + # TODO(mdan): This needs to be explained more thoroughly. + # Process any base classes: if the superclass if of a whitelisted type, an + # absolute import line is generated. Otherwise, it is marked for conversion + # (as a side effect of the call to namer.compiled_class_name() followed by + # program_ctx.update_name_map(namer)). + output_nodes = [] + renames = {} + base_names = [] + for base in c.__bases__: + if isinstance(object, base): + base_names.append('object') + continue + if is_whitelisted_for_graph(base): + alias = namer.new_symbol(base.__name__, ()) + output_nodes.append( + gast.ImportFrom( + module=base.__module__, + names=[gast.alias(name=base.__name__, asname=alias)], + level=0)) + else: + # This will trigger a conversion into a class with this name. + alias = namer.compiled_class_name(base.__name__, base) + base_names.append(alias) + renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) + program_ctx.update_name_map(namer) + + # Generate the definition of the converted class. + bases = [gast.Name(n, gast.Load(), None) for n in base_names] + class_def = gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[]) + # Make a final pass to replace references to the class or its base classes. + # Most commonly, this occurs when making super().__init__() calls. + # TODO(mdan): Making direct references to superclass' superclass will fail. + class_def = qual_names.resolve(class_def) + renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) + class_def = ast_util.rename_symbols(class_def, renames) + + output_nodes.append(class_def) + + return output_nodes, class_name, class_namespace + + +def _add_reserved_symbol(namespace, name, entity): + if name not in namespace: + namespace[name] = entity + elif namespace[name] != entity: + raise ValueError('The name "%s" is reserved and may not be used.' % name) + + +ag_internal = None + + +def _add_self_references(namespace, autograph_module): + """Adds namespace references to the module that exposes the api itself.""" + global ag_internal + if ag_internal is None: + # Craft a module that exposes parts of the external API as well as certain + # internal modules. + ag_internal = imp.new_module('autograph') + ag_internal.converted_call = autograph_module.converted_call + ag_internal.utils = utils + ag_internal.rewrite_graph_construction_error = ( + errors.rewrite_graph_construction_error) + # TODO(mdan): Add safeguards against name clashes. + # We don't want to create a submodule because we want the operators to be + # accessible as ag__.<operator> + ag_internal.__dict__.update(operators.__dict__) + + _add_reserved_symbol(namespace, 'ag__', ag_internal) + + +def function_to_graph(f, + program_ctx, + arg_values, + arg_types, + owner_type=None, + rewrite_errors=True): + """Specialization of `entity_to_graph` for callable functions.""" + + node, source = parser.parse_entity(f) + node = node.body[0] + origin_info.resolve(node, source, f) + namespace = inspect_utils.getnamespace(f) + _add_self_references(namespace, program_ctx.autograph_module) + namer = program_ctx.new_namer(namespace) + + entity_info = transformer.EntityInfo( + source_code=source, + source_file='<fragment>', + namespace=namespace, + arg_values=arg_values, + arg_types=arg_types, + owner_type=owner_type) + context = converter.EntityContext(namer, entity_info, program_ctx) + node = node_to_graph(node, context, rewrite_errors=rewrite_errors) + + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py + new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) + if not did_rename: + new_name = f.__name__ + if node.name != f.__name__: + raise NotImplementedError('Strange corner case. Send us offending code!') + node.name = new_name + + program_ctx.update_name_map(namer) + # TODO(mdan): Use this at compilation. + + return [node], new_name, namespace + + +def node_to_graph(node, context, rewrite_errors=True): + """Convert Python code to equivalent TF graph mode code. + + Args: + node: AST, the code to convert. + context: converter.EntityContext + rewrite_errors: Boolean, whether or not to rewrite the error traceback. + + Returns: + A tuple (node, deps): + * node: A Python ast node, representing the converted code. + * deps: A set of strings, the fully qualified names of entity + dependencies that this node has. + """ + # TODO(mdan): Insert list_comprehensions somewhere. + + node = converter.standard_analysis(node, context, is_initial=True) + # Past this point, line numbers are no longer accurate so we ignore the + # source. + # TODO(mdan): Is it feasible to reconstruct intermediate source code? + context.info.source_code = None + + node = converter.apply_(node, context, decorators) + node = converter.apply_(node, context, directives) + node = converter.apply_(node, context, break_statements) + node = converter.apply_(node, context, asserts) + # Note: sequencing continue canonicalization before for loop one avoids + # dealing with the extra loop increment operation that the for + # canonicalization creates. + node = converter.apply_(node, context, continue_statements) + context.info.namespace['len'] = len + node = converter.apply_(node, context, return_statements) + node = converter.apply_(node, context, lists) + node = converter.apply_(node, context, slices) + node = converter.apply_(node, context, builtin_functions) + node = converter.apply_(node, context, call_trees) + node = converter.apply_(node, context, control_flow) + node = converter.apply_(node, context, conditional_expressions) + node = converter.apply_(node, context, logical_expressions) + node = converter.apply_(node, context, side_effect_guards) + node = converter.apply_(node, context, name_scopes) + if rewrite_errors: + node = converter.apply_(node, context, error_handlers) + return node |