aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/impl/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/impl/api.py')
-rw-r--r--tensorflow/contrib/autograph/impl/api.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index c7401c7df1..f7fe3de5da 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
Returns:
A decorator that wraps the original function.
"""
+
def decorator(f):
"""Decorator implementation."""
@@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
@wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
- raise NotImplementedError(
- 'RunMode.PY_FUNC does not yet support kwargs')
+ raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
# TODO(mdan): Add support for kwargs.
return py_func.wrap_py_func(
f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
@@ -231,7 +231,10 @@ def to_graph(e,
Returns:
A function with a signature identical to `o`, but which when executed it
- creates TF a graph that has the same functionality as the original entity.
+ creates TF a graph that has the same functionality as the original entity.
+ Raises:
+ ValueError: If the converted function defines or refers to symbol names that
+ are reserved for AutoGraph.
"""
program_ctx = converter.ProgramContext(
recursive=recursive,
@@ -256,6 +259,19 @@ def to_graph(e,
compiled_node.__dict__[key] = val
compiled_fn = getattr(compiled_node, name)
+ # Need this so the source_mapping attribute is available for the context
+ # manager to access for runtime errors.
+ #
+ # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
+ # symbol to the compiled module.
+ source_map_attribute_name = 'ag_source_map'
+ if getattr(compiled_fn, source_map_attribute_name, None) is not None:
+ raise ValueError('cannot convert %s because is has an attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled_fn, source_map_attribute_name))
+ setattr(compiled_fn, source_map_attribute_name,
+ compiled_node.__dict__['ag_source_map__'])
+
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
@@ -292,7 +308,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)
+ compiler.ast_to_source(dep, indentation)[0]
for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
return program_ctx.required_imports + '\n\n' + code