diff options
Diffstat (limited to 'tensorflow/contrib/autograph/impl/api.py')
-rw-r--r-- | tensorflow/contrib/autograph/impl/api.py | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index c7401c7df1..f7fe3de5da 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): Returns: A decorator that wraps the original function. """ + def decorator(f): """Decorator implementation.""" @@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): @wraps(f) def py_func_wrapper(*args, **kwargs): if kwargs: - raise NotImplementedError( - 'RunMode.PY_FUNC does not yet support kwargs') + raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') # TODO(mdan): Add support for kwargs. return py_func.wrap_py_func( f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) @@ -231,7 +231,10 @@ def to_graph(e, Returns: A function with a signature identical to `o`, but which when executed it - creates TF a graph that has the same functionality as the original entity. + creates TF a graph that has the same functionality as the original entity. + Raises: + ValueError: If the converted function defines or refers to symbol names that + are reserved for AutoGraph. """ program_ctx = converter.ProgramContext( recursive=recursive, @@ -256,6 +259,19 @@ def to_graph(e, compiled_node.__dict__[key] = val compiled_fn = getattr(compiled_node, name) + # Need this so the source_mapping attribute is available for the context + # manager to access for runtime errors. + # + # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' + # symbol to the compiled module. + source_map_attribute_name = 'ag_source_map' + if getattr(compiled_fn, source_map_attribute_name, None) is not None: + raise ValueError('cannot convert %s because is has an attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_fn, source_map_attribute_name)) + setattr(compiled_fn, source_map_attribute_name, + compiled_node.__dict__['ag_source_map__']) + if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) @@ -292,7 +308,7 @@ def to_code(e, conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( - compiler.ast_to_source(dep, indentation) + compiler.ast_to_source(dep, indentation)[0] for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code |