aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/impl/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/impl/api.py')
-rw-r--r--tensorflow/contrib/autograph/impl/api.py52
1 files changed, 36 insertions, 16 deletions
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index c7401c7df1..4729c735c6 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -23,7 +23,6 @@ from functools import wraps
from enum import Enum
# pylint:disable=g-bad-import-order
-import gast
import six
# pylint:enable=g-bad-import-order
@@ -69,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None):
@wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)
+ return converted_call(f, recursive, verbose, True, arg_types, *args,
+ **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
Returns:
A decorator that wraps the original function.
"""
+
def decorator(f):
"""Decorator implementation."""
@@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
@wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
- raise NotImplementedError(
- 'RunMode.PY_FUNC does not yet support kwargs')
+ raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
# TODO(mdan): Add support for kwargs.
return py_func.wrap_py_func(
f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
@@ -130,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
return decorator
-def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
+def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
+ **kwargs):
"""Compiles a function call inline."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
-
- if conversion.is_whitelisted_for_graph(f):
+ if not force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
@@ -231,7 +231,10 @@ def to_graph(e,
Returns:
A function with a signature identical to `o`, but which when executed it
- creates TF a graph that has the same functionality as the original entity.
+ creates TF a graph that has the same functionality as the original entity.
+ Raises:
+ ValueError: If the converted function defines or refers to symbol names that
+ are reserved for AutoGraph.
"""
program_ctx = converter.ProgramContext(
recursive=recursive,
@@ -242,24 +245,41 @@ def to_graph(e,
_, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
arg_types)
- module = gast.Module([])
+ nodes = []
for dep in reversed(program_ctx.dependency_cache.values()):
- module.body.append(dep)
- compiled_node, compiled_src = compiler.ast_to_object(
- module, source_prefix=program_ctx.required_imports)
+ nodes.extend(dep)
+ compiled_module, compiled_src = compiler.ast_to_object(
+ nodes,
+ source_prefix=program_ctx.required_imports,
+ include_source_map=True)
# The compiled code should see everything the entry entity saw.
# TODO(mdan): This might not work well if the call tree spans modules?
for key, val in namespace.items():
# Avoid overwriting entities that have been transformed.
- if key not in compiled_node.__dict__:
- compiled_node.__dict__[key] = val
- compiled_fn = getattr(compiled_node, name)
+ if key not in compiled_module.__dict__:
+ compiled_module.__dict__[key] = val
+ compiled = getattr(compiled_module, name)
+
+ # Need this so the source_mapping attribute is available for the context
+ # manager to access for runtime errors.
+ #
+ # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
+ # symbol to the compiled module.
+ # TODO(mdan): Record this statically in the generated code.
+ # TODO(mdan): Rename this attribute to 'autograph_info__'
+ source_map_attribute_name = 'ag_source_map'
+ if getattr(compiled, source_map_attribute_name, None) is not None:
+ raise ValueError('cannot convert %s because is has an attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled, source_map_attribute_name))
+ setattr(compiled, source_map_attribute_name,
+ compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
- return compiled_fn
+ return compiled
def to_code(e,