aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/impl/conversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/impl/conversion.py')
-rw-r--r--tensorflow/python/autograph/impl/conversion.py351
1 files changed, 351 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
new file mode 100644
index 0000000000..928ff9e7ea
--- /dev/null
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -0,0 +1,351 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Core conversion logic, serves as main point of access."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+
+import gast
+
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.converters import directives
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.util import tf_inspect
+
+
+# TODO(mdan): Might we not need any renaming at all?
+
+
+def is_whitelisted_for_graph(o):
+ """Check whether an entity is whitelisted for use in graph mode.
+
+ Examples of whitelisted entities include all members of the tensorflow
+ package.
+
+ Args:
+ o: A Python entity.
+ Returns:
+ Boolean
+ """
+ m = tf_inspect.getmodule(o)
+ for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
+ if m.__name__.startswith(prefix):
+ return True
+ if hasattr(o, 'autograph_info__'):
+ return True
+ return False
+
+
+def entity_to_graph(o, program_ctx, arg_values, arg_types):
+ """Compile a Python entity into equivalent TensorFlow.
+
+ The function will also recursively compile all the entities that `o`
+ references, updating `dependency_cache`.
+
+ This function is reentrant, and relies on dependency_cache to avoid
+ generating duplicate code.
+
+ Args:
+ o: A Python entity.
+ program_ctx: A ProgramContext object.
+ arg_values: A dict containing value hints for symbols like function
+ parameters.
+ arg_types: A dict containing type hints for symbols like function
+ parameters.
+
+ Returns:
+ A tuple (ast, new_name, namespace):
+ * ast: An AST representing an entity with interface equivalent to `o`,
+ but which when executed it creates TF a graph.
+ * new_name: The symbol name under which the new entity can be found.
+ * namespace: A dict mapping all symbols visible to the converted entity,
+ keyed by their symbol name.
+
+ Raises:
+ ValueError: if the entity type is not supported.
+ """
+ if tf_inspect.isclass(o):
+ node, name, ns = class_to_graph(o, program_ctx)
+ elif tf_inspect.isfunction(o):
+ # TODO(mdan): This is not a reliable mechanism.
+ # The most reliable way is to check the source code, the AST will contain
+ # a Lambda node instead of a FunctionDef
+ if o.__name__ == '<lambda>':
+ raise NotImplementedError(
+ 'lambda functions are not yet supported; declare the function'
+ ' using def instead: %s' % o)
+ else:
+ node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
+ elif tf_inspect.ismethod(o):
+ node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
+ # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
+ elif hasattr(o, '__class__'):
+ raise NotImplementedError(
+ 'Object conversion is not yet supported. If you are '
+ 'trying to convert code that uses an existing object, '
+ 'try including the creation of that object in the '
+ 'conversion. For example, instead of converting the method '
+ 'of a class, try converting the entire class instead. '
+ 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
+ 'contrib/autograph/README.md#using-the-functional-api '
+ 'for more information.')
+ else:
+ raise ValueError(
+ 'Entity "%s" has unsupported type "%s". Only functions and classes are '
+ 'supported for now.' % (o, type(o)))
+
+ # TODO(mdan): This is temporary. it should be created using a converter.
+ # TODO(mdan): The attribute should be added with a helper, not directly.
+ # The helper can ensure there are no collisions.
+ template = '''
+ entity.autograph_info__ = {}
+ '''
+ node.extend(templates.replace(template, entity=name))
+
+ program_ctx.add_to_cache(o, node)
+
+ if program_ctx.recursive:
+ while True:
+ candidate = None
+ for obj in program_ctx.name_map.keys():
+ if obj not in program_ctx.dependency_cache:
+ candidate = obj
+ break
+ if candidate is None:
+ break
+ if (hasattr(candidate, 'im_class') and
+ getattr(candidate, 'im_class') not in program_ctx.partial_types):
+ # Class members are converted with their objects, unless they're
+ # only converted partially.
+ continue
+ entity_to_graph(candidate, program_ctx, {}, {})
+
+ return node, name, ns
+
+
+def class_to_graph(c, program_ctx):
+ """Specialization of `entity_to_graph` for classes."""
+ converted_members = {}
+ method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
+ members = tf_inspect.getmembers(c, predicate=method_filter)
+ if not members:
+ raise ValueError('Cannot convert %s: it has no member methods.' % c)
+
+ class_namespace = {}
+ for _, m in members:
+ # Only convert the members that are directly defined by the class.
+ if inspect_utils.getdefiningclass(m, c) is not c:
+ continue
+ node, _, namespace = function_to_graph(
+ m,
+ program_ctx=program_ctx,
+ arg_values={},
+ arg_types={'self': (c.__name__, c)},
+ owner_type=c,
+ rewrite_errors=False)
+ if class_namespace is None:
+ class_namespace = namespace
+ else:
+ class_namespace.update(namespace)
+ converted_members[m] = node[0]
+ namer = program_ctx.new_namer(class_namespace)
+ class_name = namer.compiled_class_name(c.__name__, c)
+
+ # TODO(mdan): This needs to be explained more thoroughly.
+ # Process any base classes: if the superclass if of a whitelisted type, an
+ # absolute import line is generated. Otherwise, it is marked for conversion
+ # (as a side effect of the call to namer.compiled_class_name() followed by
+ # program_ctx.update_name_map(namer)).
+ output_nodes = []
+ renames = {}
+ base_names = []
+ for base in c.__bases__:
+ if isinstance(object, base):
+ base_names.append('object')
+ continue
+ if is_whitelisted_for_graph(base):
+ alias = namer.new_symbol(base.__name__, ())
+ output_nodes.append(
+ gast.ImportFrom(
+ module=base.__module__,
+ names=[gast.alias(name=base.__name__, asname=alias)],
+ level=0))
+ else:
+ # This will trigger a conversion into a class with this name.
+ alias = namer.compiled_class_name(base.__name__, base)
+ base_names.append(alias)
+ renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
+ program_ctx.update_name_map(namer)
+
+ # Generate the definition of the converted class.
+ bases = [gast.Name(n, gast.Load(), None) for n in base_names]
+ class_def = gast.ClassDef(
+ class_name,
+ bases=bases,
+ keywords=[],
+ body=list(converted_members.values()),
+ decorator_list=[])
+ # Make a final pass to replace references to the class or its base classes.
+ # Most commonly, this occurs when making super().__init__() calls.
+ # TODO(mdan): Making direct references to superclass' superclass will fail.
+ class_def = qual_names.resolve(class_def)
+ renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
+ class_def = ast_util.rename_symbols(class_def, renames)
+
+ output_nodes.append(class_def)
+
+ return output_nodes, class_name, class_namespace
+
+
+def _add_reserved_symbol(namespace, name, entity):
+ if name not in namespace:
+ namespace[name] = entity
+ elif namespace[name] != entity:
+ raise ValueError('The name "%s" is reserved and may not be used.' % name)
+
+
+ag_internal = None
+
+
+def _add_self_references(namespace, autograph_module):
+ """Adds namespace references to the module that exposes the api itself."""
+ global ag_internal
+ if ag_internal is None:
+ # Craft a module that exposes parts of the external API as well as certain
+ # internal modules.
+ ag_internal = imp.new_module('autograph')
+ ag_internal.converted_call = autograph_module.converted_call
+ ag_internal.utils = utils
+ ag_internal.rewrite_graph_construction_error = (
+ errors.rewrite_graph_construction_error)
+ # TODO(mdan): Add safeguards against name clashes.
+ # We don't want to create a submodule because we want the operators to be
+ # accessible as ag__.<operator>
+ ag_internal.__dict__.update(operators.__dict__)
+
+ _add_reserved_symbol(namespace, 'ag__', ag_internal)
+
+
+def function_to_graph(f,
+ program_ctx,
+ arg_values,
+ arg_types,
+ owner_type=None,
+ rewrite_errors=True):
+ """Specialization of `entity_to_graph` for callable functions."""
+
+ node, source = parser.parse_entity(f)
+ node = node.body[0]
+ origin_info.resolve(node, source, f)
+ namespace = inspect_utils.getnamespace(f)
+ _add_self_references(namespace, program_ctx.autograph_module)
+ namer = program_ctx.new_namer(namespace)
+
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file='<fragment>',
+ namespace=namespace,
+ arg_values=arg_values,
+ arg_types=arg_types,
+ owner_type=owner_type)
+ context = converter.EntityContext(namer, entity_info, program_ctx)
+ node = node_to_graph(node, context, rewrite_errors=rewrite_errors)
+
+ # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
+ new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
+ if not did_rename:
+ new_name = f.__name__
+ if node.name != f.__name__:
+ raise NotImplementedError('Strange corner case. Send us offending code!')
+ node.name = new_name
+
+ program_ctx.update_name_map(namer)
+ # TODO(mdan): Use this at compilation.
+
+ return [node], new_name, namespace
+
+
+def node_to_graph(node, context, rewrite_errors=True):
+ """Convert Python code to equivalent TF graph mode code.
+
+ Args:
+ node: AST, the code to convert.
+ context: converter.EntityContext
+ rewrite_errors: Boolean, whether or not to rewrite the error traceback.
+
+ Returns:
+ A tuple (node, deps):
+ * node: A Python ast node, representing the converted code.
+ * deps: A set of strings, the fully qualified names of entity
+ dependencies that this node has.
+ """
+ # TODO(mdan): Insert list_comprehensions somewhere.
+
+ node = converter.standard_analysis(node, context, is_initial=True)
+ # Past this point, line numbers are no longer accurate so we ignore the
+ # source.
+ # TODO(mdan): Is it feasible to reconstruct intermediate source code?
+ context.info.source_code = None
+
+ node = converter.apply_(node, context, decorators)
+ node = converter.apply_(node, context, directives)
+ node = converter.apply_(node, context, break_statements)
+ node = converter.apply_(node, context, asserts)
+ # Note: sequencing continue canonicalization before for loop one avoids
+ # dealing with the extra loop increment operation that the for
+ # canonicalization creates.
+ node = converter.apply_(node, context, continue_statements)
+ context.info.namespace['len'] = len
+ node = converter.apply_(node, context, return_statements)
+ node = converter.apply_(node, context, lists)
+ node = converter.apply_(node, context, slices)
+ node = converter.apply_(node, context, builtin_functions)
+ node = converter.apply_(node, context, call_trees)
+ node = converter.apply_(node, context, control_flow)
+ node = converter.apply_(node, context, conditional_expressions)
+ node = converter.apply_(node, context, logical_expressions)
+ node = converter.apply_(node, context, side_effect_guards)
+ node = converter.apply_(node, context, name_scopes)
+ if rewrite_errors:
+ node = converter.apply_(node, context, error_handlers)
+ return node