-rw-r--r--  tensorflow/python/BUILD                            23
-rw-r--r--  tensorflow/python/eager/function.py                13
-rw-r--r--  tensorflow/python/kernel_tests/BUILD               25
-rw-r--r--  tensorflow/python/kernel_tests/while_v2_test.py   252
-rw-r--r--  tensorflow/python/ops/while_v2.py                 573
-rw-r--r--  tensorflow/tools/pip_package/BUILD                  1
6 files changed, 885 insertions, 2 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2eeae773d3..d70e9c5798 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1998,6 +1998,29 @@ py_library(
)
py_library(
+ name = "while_v2",
+ srcs = [
+ "ops/while_v2.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":cond_v2_impl",
+ ":constant_op",
+ ":control_flow_util",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":functional_ops_gen",
+ ":gradients_impl",
+ ":list_ops",
+ ":tensor_shape",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+py_library(
name = "cond_v2_impl",
srcs = [
"ops/cond_v2_impl.py",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 4f1a85a274..a68c6ab3b4 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -826,7 +826,12 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwds,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -841,6 +846,8 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
`kwds` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+ func_graph: Optional. An instance of FuncGraph. If provided, the function
+ is traced into this graph; otherwise a new one is built and returned.
Returns:
A FuncGraph.
@@ -849,7 +856,9 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
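
The hunk above lets `func_graph_from_py_func` trace into a caller-supplied FuncGraph instead of always creating a fresh one; this is what allows while_v2 (below) to trace the gradient body into its own `_WhileBodyGradFuncGraph` subclass. The following is a minimal sketch of the intended call pattern, not part of the patch; `_LoggingFuncGraph` is a hypothetical subclass used only to illustrate that the provided graph's capture hooks are exercised during tracing.

from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op


class _LoggingFuncGraph(function.FuncGraph):
  """Hypothetical subclass: records the name of every captured tensor."""

  def __init__(self, name):
    super(_LoggingFuncGraph, self).__init__(name)
    self.captured_names = []

  def capture(self, tensor, name=None):
    self.captured_names.append(tensor.name)
    return super(_LoggingFuncGraph, self).capture(tensor, name)


x = constant_op.constant(3.)
graph = _LoggingFuncGraph("traced_fn")
# `x` is external to the traced lambda, so it is routed through
# _LoggingFuncGraph.capture while tracing happens inside `graph`.
traced = function.func_graph_from_py_func(
    "traced_fn", lambda v: v * x, [constant_op.constant(1.)], {},
    func_graph=graph)
assert traced is graph
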
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 100240a626..a048eaa69f 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3204,3 +3204,28 @@ tf_py_test(
grpc_enabled = True,
tags = ["no_gpu"], # TODO(b/111656070)
)
+
+# TODO(b/116053459): Replace with cuda_py_test.
+tf_py_test(
+ name = "while_v2_test",
+ size = "medium",
+ srcs = ["while_v2_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients_impl",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tf_optimizer",
+ "//tensorflow/python:while_v2",
+ ],
+ grpc_enabled = True,
+ tags = ["no_gpu"], # TODO(b/116053459)
+)
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
new file mode 100644
index 0000000000..d00e39d482
--- /dev/null
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -0,0 +1,252 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for while_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import while_v2
+from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
+from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
+from tensorflow.python.platform import test
+
+
+class WhileV2Test(test.TestCase, parameterized.TestCase):
+
+ def testSingleLoopVar(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testMultipleLoopVarsBasic(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
+ # ret = [x*y^2, y]
+
+ # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+ grad = gradients_impl.gradients(ret, [x]) # [y**2]
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [45., 3.])
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testMultipleLoopVars(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x, y = x * y, x + y (both updated simultaneously from the old values)
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
+ [x, y])
+ # ret = [y*x**2 + x*y**2, x*y + x + y]
+
+ gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2]
+ gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1]
+ gradx_2 = gradients_impl.gradients(ret, [x]) # [2*x*y + y**2 + y + 1]
+ grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
+ grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
+ grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [120., 23.])
+ self.assertSequenceEqual(sess.run(gradx_0), [39.])
+ self.assertSequenceEqual(sess.run(gradx_1), [4.])
+ self.assertSequenceEqual(sess.run(gradx_2), [43.])
+ self.assertSequenceEqual(sess.run(grady_0), [55.])
+ self.assertSequenceEqual(sess.run(grady_1), [6.])
+ self.assertSequenceEqual(sess.run(grady_2), [61.])
+
+ def testMultipleWhileLoops(self):
+ x = constant_op.constant(2.)
+ ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x]) # x**2
+ ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
+ grad = gradients_impl.gradients(ret2, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.test_session() as sess:
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testDoubleDerivative(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
+ grad = gradients_impl.gradients(ret, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testPruning(self):
+ x = constant_op.constant(1)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=x.dtype, element_shape=x.shape)
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5
+
+ def Body(x, tl):
+ return x + 1, list_ops.tensor_list_push_back(tl, x)
+
+ outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(outputs[0])
+
+ def GetOptimizedGraph():
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ return tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+
+ stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+ train_op.append(stack)
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+
+ def testCaptureExternalTensorInCond(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(1.)
+ ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testCaptureExternalTensorInBody(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(3.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testLoopWithTensorListPushBack(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ tl = list_ops.tensor_list_push_back(tl, x)
+ tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testDuplicateAccumulator(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ # There is an accumulator in the loop already so we should not add
+ # another.
+ tl = list_ops.tensor_list_push_back(tl, x)
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+
+ for op in ops.get_default_graph().get_operations():
+ if op.type == "While":
+ while_op = op
+
+ body_graph = while_v2._get_body_graph(while_op)
+ # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
+ x_input_t = body_graph.inputs[1]
+ accumulator_count = len(
+ [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
+ self.assertEqual(accumulator_count, 1)
+
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ @parameterized.named_parameters(
+ ("Unknown shape", None),
+ ("Partially defined shape", [None]),
+ ("Fully defined shape", [1, 2]),
+ )
+ def testTensorListOutputElementShape(self, shape):
+ self.skipTest("b/115982901")
+ x = constant_op.constant(2.)
+ y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
+ ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
+
+ # Get the TensorList output of While op containing the accumulated values
+ # of y.
+ while_op = ret[0].op
+ body_graph = while_v2._get_body_graph(while_op)
+ # body_graph.inputs: [counter_arg, x_arg, y_arg, *accumulators]
+ y_input_t = body_graph.inputs[2]
+ push_back_node = [c for c in y_input_t.consumers()
+ if c.type == "TensorListPushBack"][0]
+ output_idx = body_graph.outputs.index(push_back_node.outputs[0])
+ output = while_op.outputs[output_idx]
+
+ _, val = list_ops.tensor_list_pop_back(output,
+ element_dtype=dtypes.float32)
+ self.assertEqual(val.shape, tensor_shape.TensorShape(shape))
+
+
+def ScalarShape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+if __name__ == "__main__":
+ test.main()
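
For readers checking the expected values in testMultipleLoopVars above: the loop runs exactly two iterations with x and y updated simultaneously from the old values, so the asserted outputs and gradients can be verified with plain Python arithmetic. This check is illustrative only and is not part of the test file.

x, y = 5., 3.
ret0 = x**2 * y + x * y**2      # 120. (first loop output)
ret1 = x * y + x + y            # 23.  (second loop output)
dret0_dx = 2 * x * y + y**2     # 39.
dret1_dx = y + 1                # 4.
dret_dx = dret0_dx + dret1_dx   # 43.
dret0_dy = x**2 + 2 * x * y     # 55.
dret1_dy = x + 1                # 6.
dret_dy = dret0_dy + dret1_dy   # 61.
assert (ret0, ret1) == (120., 23.)
assert (dret0_dx, dret1_dx, dret_dx) == (39., 4., 43.)
assert (dret0_dy, dret1_dy, dret_dy) == (55., 6., 61.)
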
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000000..801217fe66
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,573 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+ """Like tf.while_loop, except emits a single While op."""
+ if not name:
+ name = "while"
+
+ with ops.name_scope(name) as scope:
+ with ops.name_scope(None):
+ cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+ body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+ flattened_loop_vars = nest.flatten(loop_vars)
+ num_outputs = len(flattened_loop_vars)
+
+ # Add loop counter needed for computing gradients.
+ flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+ ] + flattened_loop_vars
+
+ # Build a `cond` wrapper that can handle the extra counter loop_var.
+ def wrapped_cond(unused_loop_counter, *loop_vars):
+ return cond(*loop_vars)
+
+ cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+ flattened_loop_vars, {})
+
+ # Add external_captures of cond to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+ def wrapped_body(loop_counter, *args):
+ """Loop body augmented with counter update.
+
+ Args:
+ loop_counter: Loop counter which needs to be incremented in the body.
+ *args: List of args
+ args[:num_outputs] - Args for the original loop body.
+ args[num_outputs:] - External captures of cond. These get passed
+ through as is.
+
+ Returns:
+ A list of tensors the same length as args.
+ """
+ outputs = body(*args[:num_outputs])
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+
+ # Return the external_captures of cond_graph as is, i.e., treat them as
+ # loop invariants.
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+ body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+ flattened_loop_vars, {})
+ # Add external captures of body to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ body_graph.outputs.extend(body_graph.internal_captures)
+
+ # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+ # expects to receive those as arguments.
+ # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+ # This logic already exists in cond_v2.
+ with cond_graph.as_default():
+ for external_capture in body_graph.external_captures:
+ cond_graph.capture(external_capture)
+
+ # Export all tensors in the loop body that may be needed for gradient
+ # computation. We do this by accumulating the intermediate values in
+ # TensorLists.
+ intermediate_tensors = _get_intermediates(body_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ # TODO(srbs): Cache and re-use empty tensor lists.
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(
+ intermediate_tensor.shape))
+ flattened_loop_vars.append(tensor_list)
+ with cond_graph.as_default():
+ # Add a placeholder to cond_graph's inputs corresponding to the
+ # tensor_list.
+ cond_graph.capture(tensor_list)
+ with body_graph.as_default():
+ # Push the intermediate tensor to the tensor list. This captures the
+ # `tensor_list` as well.
+ appended_tensor_list = list_ops.tensor_list_push_back(
+ tensor_list,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_graph.outputs.append(appended_tensor_list)
+
+ outputs = gen_functional_ops._while(
+ flattened_loop_vars,
+ cond_v2._create_new_tf_function(cond_graph),
+ cond_v2._create_new_tf_function(body_graph),
+ name=scope)
+
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # First var is loop counter.
+ if num_outputs == 1:
+ return outputs[1]
+ else:
+ return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of a While op produced by while_loop."""
+ body_graph = _get_body_graph(op)
+
+ # Replace None gradients with zeros. This is needed because `grads` could have
+ # None incoming gradients for the TensorLists. If we pass None's through, the
+ # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+ # the FuncGraph which is undesirable.
+ # TODO(b/80444525): There might be an issue with treating no gradient as zero
+ # gradient in certain cases. Consider replacing None gradients with Zeros
+ # for accumulators only.
+ grads = [
+ g if g is not None else array_ops.zeros_like(output)
+ for g, output in zip(grads, op.outputs)
+ ]
+
+ body_grad_graph, args = _create_grad_func(
+ body_graph, grads,
+ _get_unique_name("%s_grad" % body_graph.name), op)
+
+ intermediate_tensors = _get_intermediates(body_grad_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+ with body_grad_graph.as_default():
+ tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+ # Push the intermediate tensor to the tensor list.
+ appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_grad_graph.outputs.append(appended_tensor_list)
+
+ def grad_cond(counter, max_iters, *unused_args):
+ return counter < max_iters
+
+ loop_vars = args + body_grad_graph.external_captures
+ cond_grad_graph = function.func_graph_from_py_func(
+ _get_unique_name("%s_grad_cond" % op.name),
+ grad_cond, loop_vars, {})
+
+ assert len(loop_vars) == len(body_grad_graph.inputs)
+ assert len(loop_vars) == len(body_grad_graph.outputs)
+ assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+ outputs = gen_functional_ops._while(
+ loop_vars,
+ cond_v2._create_new_tf_function(cond_grad_graph),
+ cond_v2._create_new_tf_function(body_grad_graph),
+ name=_get_unique_name("%s_grad" % op.name))
+
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ return outputs[2:2 + len(op.inputs)]
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+ """Returns `FuncGraph` for the while body.
+
+ Args:
+ while_op: The While Operation.
+
+ Returns:
+ `FuncGraph` for the while body.
+ """
+ extra_inputs = list(while_op.inputs)
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = while_op.get_attr("body").name
+ fdef = while_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph._while = while_op
+ return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+ """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+ The returned grad_func_graph must be called with the returned
+ args + grad_func_graph.captures.
+
+ Args:
+ func_graph: FuncGraph for the forward body function.
+ grads: The incoming grads for `func_graph`'s outputs.
+ name: Name of the returned gradient function.
+ while_op: The forward While op.
+
+ Returns:
+ 2-tuple of (grad_func_graph, args).
+ """
+ assert len(func_graph.outputs) == len(grads)
+
+ loop_counter = constant_op.constant(0.)
+ # TODO(srbs): For nested while loops we will need to look up this value
+ # from the accumulator of the enclosing while loop. For now, use it as is,
+ # assuming there is no nesting.
+ num_iters_t = while_op.outputs[0]
+
+ args = [loop_counter, num_iters_t] + grads
+
+ # Note: The returned function does not have `args` in the list of
+ # `external_captures`.
+ grad_func_graph = function.func_graph_from_py_func(
+ name,
+ lambda *args: _grad_fn(func_graph, args),
+ args, {},
+ func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+ # Add the popped accumulators to the list of outputs.
+ for internal_capture in grad_func_graph.internal_captures:
+ grad_func_graph.outputs.append(
+ grad_func_graph.popped_tensor_lists[internal_capture])
+
+ return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+ """Computes the gradient of `func_graph` in the current graph.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
+ args: The input arguments.
+ args[0] - Loop counter.
+ args[1] - Total number of iterations.
+ args[2:] - Incoming gradients for `func_graph.outputs`.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ xs = func_graph.inputs
+ ys = func_graph.outputs
+ grad_ys = args[2:]
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _resolve_grad_inputs.
+ # TODO(srbs): Mark GradientsHelper as public?
+ grad_outs = gradients_impl._GradientsHelper(
+ ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+ assert all([g is not None for g in grad_outs])
+ counter = args[0]
+ total_iters = args[1]
+ return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that should be accumulated."""
+ # We currently accumulate output tensors of most ops in the function and rely
+ # on the pruning pass to get rid of the unused accumulators at runtime.
+ # However, this can bloat the GraphDef and make debugging harder so we perform
+ # some optimizations.
+ #
+ # Optimization we currently perform:
+ # 1. We do not accumulate tensors which already have an accumulator
+ # in the loop body.
+ # 2. We do not accumulate outputs of Identity nodes. When building the
+ # FuncGraph, we add an Identity node for each output (see
+ # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+ # of all these nodes bloats the GraphDef quite a bit so we remove those.
+ # Since the gradient of an Identity node does not rely on its forward op's
+ # input this is safe to do.
+ #
+ # Other possible optimizations:
+ # 1. Only accumulate tensors that will be required by the backward pass.
+ # This will require running the gradient pass and hence would increase the
+ # graph building time for the forward pass.
+ # 2. Do not accumulate Const nodes created inside the loop body.
+ # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+ # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+ # since it requires knowledge of the while loop semantics. If so, consider
+ # doing those here.
+ intermediates = []
+
+ for op in func_graph.get_operations():
+ if op.type == "Identity":
+ continue
+ for o in op.outputs:
+ if (o != func_graph.inputs[0] and # Loop counter.
+ _get_accumulator(o) is None): # No existing accumulator.
+ intermediates.append(o)
+ return intermediates
+
+
+def _get_accumulator(tensor):
+ r"""Returns TensorList if any containing accumulated values of tensor.
+
+ We try to find a pattern of the form:
+
+ input_tl tensor
+ \ /
+ (TensorListPushBack)
+ |
+ output_tl
+
+ which satisfies the following conditions:
+
+ 1. input_tl must be in tensor.graph.inputs.
+ 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+ 3. tensor.graph.inputs.index(input_tl) == tensor.graph.outputs.index(output_tl).
+
+ If such a pattern is found, output_tl or Identity(output_tl) (whichever is in
+ tensor.graph.outputs) is returned; otherwise None is returned.
+
+ Args:
+ tensor: The Tensor to be accumulated.
+
+ Returns:
+ A variant tensor in the same graph as `tensor` or None if no accumulator is
+ found.
+ """
+ assert isinstance(tensor.graph, function.FuncGraph)
+
+ def get_func_graph_output(t):
+ """Returns t or Identity(t) whichever exists in graph outputs else None."""
+ if t in tensor.graph.outputs:
+ return t
+ # tf.defun adds an Identity for each output, check whether that is the case.
+ identity_op = t.consumers()[0]
+ if (identity_op.type == "Identity" and
+ identity_op.outputs[0] in tensor.graph.outputs):
+ return identity_op.outputs[0]
+ return None
+
+ for consumer in tensor.consumers():
+ # Find the consumer that is a TensorListPushBack node whose TensorList input
+ # is in the list of function inputs.
+ if (consumer.type != "TensorListPushBack" or
+ consumer.inputs[0] not in tensor.graph.inputs):
+ continue
+
+ output = get_func_graph_output(consumer.outputs[0])
+ if output is None:
+ # The TensorList output of `consumer` is not in the list of function
+ # outputs.
+ continue
+
+ accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+ accum_output_idx = tensor.graph.outputs.index(output)
+ if accum_input_idx == accum_output_idx:
+ return output
+ return None
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+ """Returns a name that is unique in the root graph of `func_graph`.
+
+ Args:
+ name: String to uniquify.
+
+ Returns:
+ A string.
+ """
+ with ops.init_scope():
+ return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+ """FuncGraph for the gradient function of the body of a While op.
+
+ Contains the logic for capturing the tensors from the body of the forward
+ While op which is as follows:
+ 1. Find the accumulator for that tensor.
+ 2. Capture the forward While op output tensor corresponding to the
+ accumulator in this FuncGraph.
+ 3. Pop a value from the captured placeholder and use it as the captured value
+ for the forward pass tensor.
+
+ This only allows capturing tensors in the forward graph. A ValueError is
+ raised if an attempt is made to capture a tensor not in the forward graph.
+ To manually capture a tensor that is not in the forward graph, call
+ `capture` with `whitelisted=True`.
+
+ Note: The `captures` dict does not contain the forward tensor since it is not
+ directly captured. It contains the accumulator corresponding to this forward
+ tensor.
+
+ Attributes:
+ popped_tensor_lists: Dict from the captured accumulator placeholder to the
+ TensorList obtained after popping the intermediate tensor from it. The
+ values of this dict need to be added to the list of outputs.
+ """
+
+ def __init__(self, name, forward_graph):
+ super(_WhileBodyGradFuncGraph, self).__init__(name)
+ self.popped_tensor_lists = {}
+ # FuncGraph for the body of the forward While op.
+ self._forward_graph = forward_graph
+ # Dict from forward intermediate tensor to the corresponding "popped" tensor
+ # in this graph.
+ self._indirect_captures = {}
+ # Dict from forward graph tensor to the While op output corresponding to its
+ # accumulator.
+ self._tensor_to_accumulator = {}
+
+ def capture(self, tensor, name=None, whitelisted=False):
+ """Selectively captures external tensors.
+
+ If `whitelisted` is False only allows capturing tensors in the
+ `_forward_graph`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+ whitelisted: If False (default), only allows capturing tensors from the
+ forward graph.
+
+ Returns:
+ The placeholder in this graph for the tensor.
+
+ Raises:
+ ValueError: If attempting to capture an external tensor not in the forward
+ graph with `whitelisted` set to False.
+ """
+ if (not whitelisted and tensor.graph is not self and
+ tensor.graph != self._forward_graph):
+ raise ValueError("Attempting to capture tensor", str(tensor),
+ " which is not in the forward graph but in ",
+ _graph_name(tensor.graph), ".")
+ return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+ def _capture_helper(self, tensor, name):
+ if tensor.graph is not self._forward_graph:
+ return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+ captured_tensor = self._indirect_captures.get(tensor)
+ if captured_tensor is not None:
+ # For GradientTape housekeeping.
+ assert self._tensor_to_accumulator[tensor] in self.captures
+ super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ self._tensor_to_accumulator[tensor], name)
+ return captured_tensor
+
+ assert tensor not in self._tensor_to_accumulator
+
+ accumulator = None
+
+ # Find the TensorList that was used to accumulate the tensors of this
+ # intermediate tensor.
+ accumulator = _get_accumulator(tensor)
+ if accumulator is None:
+ raise ValueError("Reference to un-accumulated intermediate tensor: ",
+ tensor.name)
+ assert accumulator.graph == self._forward_graph
+ # Get the While op output corresponding to the accumulator.
+ accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+ .index(accumulator)]
+
+ assert accumulator.graph == self._forward_graph.outer_graph
+ self._tensor_to_accumulator[tensor] = accumulator
+
+ # Capture the `accumulator`.
+ accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ accumulator, name)
+ new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+ accumulator_ph, element_dtype=tensor.dtype)
+ self._indirect_captures[tensor] = captured_tensor
+ self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+ return captured_tensor
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+ """Sets the flag to enable lowering on the `While` op if necessary.
+
+ Lowering allows while_v2 to avoid some of the limitations of Functions,
+ allowing users to specify devices & colocation inside of while_v2
+ branches, and enabling non-strict evaluation & partial pruning of while_v2
+ branches. This brings while_v2 closer to feature parity with
+ tf.while_loop.
+
+ However, we do not lower `While` in the XLA context because it is easier
+ for XLA to apply its own optimizations when dealing with un-lowered
+ `While` operators than with low-level control flow primitives.
+
+ Args:
+ op: The While op.
+ """
+ if not control_flow_util.IsInXLAContext(op):
+ # pylint: disable=protected-access
+ op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+ assert isinstance(shape, tensor_shape.TensorShape)
+ if shape.is_fully_defined():
+ return shape
+ if not shape: # Unknown shape.
+ return -1
+ # Partially defined shape.
+ shape_list = shape.as_list()
+ shape_list = [s if s is not None else -1 for s in shape_list]
+ return ops.convert_to_tensor(shape_list)
+
+
+def _graph_name(graph):
+ if isinstance(graph, function.FuncGraph):
+ return graph.name
+ return "Base"
+
+
+# pylint: enable=protected-access
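
To make the accumulator machinery in this file concrete: the forward loop pushes each intermediate tensor onto a TensorList (one accumulator per intermediate, appended to the loop variables), and the gradient body later pops those values back off via `_WhileBodyGradFuncGraph._capture_helper`. Below is a minimal sketch of that push/pop pattern in isolation, unrolled with a plain Python loop and independent of the While op machinery; it is illustrative only.

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import list_ops

# Scalar element shape, as in ScalarShape() in the test file.
scalar_shape = ops.convert_to_tensor([], dtype=dtypes.int32)
acc = list_ops.empty_tensor_list(
    element_dtype=dtypes.float32, element_shape=scalar_shape)

# "Forward" direction: accumulate the intermediate value of each iteration.
x = constant_op.constant(2.)
for _ in range(3):
  acc = list_ops.tensor_list_push_back(acc, x)
  x = x * x

# "Backward" direction: pop the intermediates back off in reverse order,
# which is what the gradient body does with the captured accumulator.
popped = []
for _ in range(3):
  acc, value = list_ops.tensor_list_pop_back(acc, element_dtype=dtypes.float32)
  popped.append(value)
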
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 31a3712de8..f86cb03995 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -114,6 +114,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:while_v2",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]