diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-18 20:33:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 20:36:44 -0700 |
commit | 1b2d0fcee82ec501cc692dc735065d73c6b5b834 (patch) | |
tree | 2b29182792e193e9695b796d7e4c92b95ab040c6 /tensorflow/python/eager | |
parent | 9fe177881224571aff0c267593f747f5fd7a2967 (diff) |
First commit for functional while loop.
Supports single and double derivatives but does not supporting nesting yet.
https://github.com/tensorflow/community/pull/13
PiperOrigin-RevId: 213565971
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 4f1a85a274..a68c6ab3b4 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -826,7 +826,12 @@ def _get_defun_inputs_from_args(args): return nest.pack_sequence_as(args, function_inputs) -def func_graph_from_py_func(name, python_func, args, kwds, signature=None): +def func_graph_from_py_func(name, + python_func, + args, + kwds, + signature=None, + func_graph=None): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -841,6 +846,8 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): `kwds` are ignored, and `python_func` is traced with Tensors conforming to `signature`. If `None`, the shapes and dtypes are inferred from the inputs. + func_graph: Optional. An instance of FuncGraph. If provided, we will use + this graph else a new one is built and returned. Returns: A FuncGraph. @@ -849,7 +856,9 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(name) + if func_graph is None: + func_graph = FuncGraph(name) + assert isinstance(func_graph, FuncGraph) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) |