aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-18 20:33:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 20:36:44 -0700
commit1b2d0fcee82ec501cc692dc735065d73c6b5b834 (patch)
tree2b29182792e193e9695b796d7e4c92b95ab040c6 /tensorflow/python/eager
parent9fe177881224571aff0c267593f747f5fd7a2967 (diff)
First commit for functional while loop.
Supports single and double derivatives but does not supporting nesting yet. https://github.com/tensorflow/community/pull/13 PiperOrigin-RevId: 213565971
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 4f1a85a274..a68c6ab3b4 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -826,7 +826,12 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwds,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -841,6 +846,8 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
`kwds` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+ func_graph: Optional. An instance of FuncGraph. If provided, we will use
+ this graph else a new one is built and returned.
Returns:
A FuncGraph.
@@ -849,7 +856,9 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)