-rw-r--r--  tensorflow/python/BUILD                          |  23
-rw-r--r--  tensorflow/python/eager/function.py              |  13
-rw-r--r--  tensorflow/python/kernel_tests/BUILD             |  25
-rw-r--r--  tensorflow/python/kernel_tests/while_v2_test.py  | 252
-rw-r--r--  tensorflow/python/ops/while_v2.py                | 573
-rw-r--r--  tensorflow/tools/pip_package/BUILD               |   1
6 files changed, 885 insertions, 2 deletions
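
Before the per-file hunks, here is a minimal usage sketch of the API this change adds. It is adapted directly from testSingleLoopVar in the new while_v2_test.py below and assumes a TF 1.x graph-mode build that contains this change; the internal import paths are the same ones the diff itself uses.

    # Sketch only: mirrors testSingleLoopVar from the new test file.
    from tensorflow.python.client import session
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import gradients_impl
    from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2

    x = constant_op.constant(2.)
    # Squares x until it reaches at least 8: 2 -> 4 -> 16, i.e. ret == x**4.
    ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
    # Gradient flows through the single While op: d(x**4)/dx = 4*x**3 = 32.
    grad = gradients_impl.gradients(ret, [x])

    with session.Session() as sess:
      print(sess.run(ret))   # 16.0
      print(sess.run(grad))  # [32.0]
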
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2eeae773d3..d70e9c5798 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1998,6 +1998,29 @@ py_library( ) py_library( + name = "while_v2", + srcs = [ + "ops/while_v2.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":cond_v2_impl", + ":constant_op", + ":control_flow_util", + ":framework_ops", + ":function_def_to_graph", + ":functional_ops_gen", + ":gradients_impl", + ":list_ops", + ":tensor_shape", + ":util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:function", + ], +) + +py_library( name = "cond_v2_impl", srcs = [ "ops/cond_v2_impl.py", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 4f1a85a274..a68c6ab3b4 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -826,7 +826,12 @@ def _get_defun_inputs_from_args(args): return nest.pack_sequence_as(args, function_inputs) -def func_graph_from_py_func(name, python_func, args, kwds, signature=None): +def func_graph_from_py_func(name, + python_func, + args, + kwds, + signature=None, + func_graph=None): """Returns a `FuncGraph` generated from `python_func`. Args: @@ -841,6 +846,8 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): `kwds` are ignored, and `python_func` is traced with Tensors conforming to `signature`. If `None`, the shapes and dtypes are inferred from the inputs. + func_graph: Optional. An instance of FuncGraph. If provided, we will use + this graph else a new one is built and returned. Returns: A FuncGraph. @@ -849,7 +856,9 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(name) + if func_graph is None: + func_graph = FuncGraph(name) + assert isinstance(func_graph, FuncGraph) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 100240a626..a048eaa69f 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3204,3 +3204,28 @@ tf_py_test( grpc_enabled = True, tags = ["no_gpu"], # TODO(b/111656070) ) + +# TODO(b/116053459): Replace with cuda_py_test. +tf_py_test( + name = "while_v2_test", + size = "medium", + srcs = ["while_v2_test.py"], + additional_deps = [ + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients_impl", + "//tensorflow/python:list_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tf_optimizer", + "//tensorflow/python:while_v2", + ], + grpc_enabled = True, + tags = ["no_gpu"], # TODO(b/116053459) +) diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py new file mode 100644 index 0000000000..d00e39d482 --- /dev/null +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -0,0 +1,252 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for while_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import while_v2 +from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1 +from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2 +from tensorflow.python.platform import test + + +class WhileV2Test(test.TestCase, parameterized.TestCase): + + def testSingleLoopVar(self): + x = constant_op.constant(2.) + ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x]) + grad = gradients_impl.gradients(ret, [x]) + with self.test_session() as sess: + self.assertEqual(sess.run(ret), 16.) + self.assertSequenceEqual(sess.run(grad), [32.]) + + def testMultipleLoopVarsBasic(self): + x = constant_op.constant(5.) + y = constant_op.constant(3.) + + # x = 5. + # y = 3. + # while x < 45.: + # x = x * y + ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y]) + # ret = [x*y^2, y] + + # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0. + grad = gradients_impl.gradients(ret, [x]) # [2*x*y] + with self.test_session() as sess: + self.assertSequenceEqual(sess.run(ret), [45., 3.]) + self.assertSequenceEqual(sess.run(grad), [9.]) + + def testMultipleLoopVars(self): + x = constant_op.constant(5.) + y = constant_op.constant(3.) + + # x = 5. + # y = 3. 
+ # while x < 45.: + # x = x * y + # y = x + y + ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w), + [x, y]) + # ret = [y*x**2 + x*y**2, x*y + x + y] + + gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2] + gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1] + gradx_2 = gradients_impl.gradients(ret, [x]) # [2*x*y + y**2 + 2*y + 1] + grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2] + grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1] + grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1] + with self.test_session() as sess: + self.assertSequenceEqual(sess.run(ret), [120., 23.]) + self.assertSequenceEqual(sess.run(gradx_0), [39.]) + self.assertSequenceEqual(sess.run(gradx_1), [4.]) + self.assertSequenceEqual(sess.run(gradx_2), [43.]) + self.assertSequenceEqual(sess.run(grady_0), [55.]) + self.assertSequenceEqual(sess.run(grady_1), [6.]) + self.assertSequenceEqual(sess.run(grady_2), [61.]) + + def testMultipleWhileLoops(self): + x = constant_op.constant(2.) + ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x]) # x**2 + ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4 + grad = gradients_impl.gradients(ret2, [x]) # 4x**3 + grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2 + with self.test_session() as sess: + self.assertSequenceEqual(sess.run(grad), [32.]) + self.assertSequenceEqual(sess.run(grad_grad), [48.]) + + def testDoubleDerivative(self): + x = constant_op.constant(2.) + ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4 + grad = gradients_impl.gradients(ret, [x]) # 4x**3 + grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2 + with self.test_session() as sess: + self.assertEqual(sess.run(ret), 16.) + self.assertSequenceEqual(sess.run(grad), [32.]) + self.assertSequenceEqual(sess.run(grad_grad), [48.]) + + def testPruning(self): + x = constant_op.constant(1) + + tensor_list = list_ops.empty_tensor_list( + element_dtype=x.dtype, element_shape=x.shape) + + def Cond(x, tl): + del tl # Unused for Cond. + return x < 5 + + def Body(x, tl): + return x + 1, list_ops.tensor_list_push_back(tl, x) + + outputs = while_loop_v1(Cond, Body, [x, tensor_list]) + + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(outputs[0]) + + def GetOptimizedGraph(): + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + rewriter_config = rewriter_config_pb2.RewriterConfig( + constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL) + return tf_optimizer.OptimizeGraph(rewriter_config, mg) + + g = GetOptimizedGraph() + self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1) + + stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype) + train_op.append(stack) + g = GetOptimizedGraph() + self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2) + + def testCaptureExternalTensorInCond(self): + x = constant_op.constant(2.) + y = constant_op.constant(1.) + ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x]) + grad = gradients_impl.gradients(ret, [x]) + with self.test_session() as sess: + self.assertEqual(sess.run(ret), 18.) + self.assertSequenceEqual(sess.run(grad), [9.]) + + def testCaptureExternalTensorInBody(self): + x = constant_op.constant(2.) + y = constant_op.constant(3.) 
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x]) + grad = gradients_impl.gradients(ret, [x]) + with self.test_session() as sess: + self.assertEqual(sess.run(ret), 18.) + self.assertSequenceEqual(sess.run(grad), [9.]) + + def testLoopWithTensorListPushBack(self): + x = constant_op.constant(2.) + + tensor_list = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=ScalarShape()) + + def Cond(x, tl): + del tl # Unused for Cond. + return x < 5. + + def Body(x, tl): + tl = list_ops.tensor_list_push_back(tl, x) + tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.)) + return x**2., tl + + ret = while_loop_v2(Cond, Body, [x, tensor_list]) + grad = gradients_impl.gradients(ret[0], x) + with self.test_session() as sess: + self.assertEqual(sess.run(ret[0]), 16.) + self.assertSequenceEqual(sess.run(grad), [32.]) + + def testDuplicateAccumulator(self): + x = constant_op.constant(2.) + + tensor_list = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=ScalarShape()) + + def Cond(x, tl): + del tl # Unused for Cond. + return x < 5. + + def Body(x, tl): + # There is an accumulator in the loop already so we should not add + # another. + tl = list_ops.tensor_list_push_back(tl, x) + return x**2., tl + + ret = while_loop_v2(Cond, Body, [x, tensor_list]) + + for op in ops.get_default_graph().get_operations(): + if op.type == "While": + while_op = op + + body_graph = while_v2._get_body_graph(while_op) + # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators] + x_input_t = body_graph.inputs[1] + accumulator_count = len( + [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"]) + self.assertEqual(accumulator_count, 1) + + grad = gradients_impl.gradients(ret[0], x) + with self.test_session() as sess: + self.assertEqual(sess.run(ret[0]), 16.) + self.assertSequenceEqual(sess.run(grad), [32.]) + + @parameterized.named_parameters( + ("Unknown shape", None), + ("Partially defined shape", [None]), + ("Fully defined shape", [1, 2]), + ) + def testTensorListOutputElementShape(self, shape): + self.skipTest("b/115982901") + x = constant_op.constant(2.) + y = array_ops.placeholder(dtype=dtypes.float32, shape=shape) + ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y]) + + # Get the TensorList output of While op containing the accumulated values + # of y. + while_op = ret[0].op + body_graph = while_v2._get_body_graph(while_op) + # body_graph.inputs: [counter_arg, x_arg, y_arg, *accumulators] + y_input_t = body_graph.inputs[2] + push_back_node = [c for c in y_input_t.consumers() + if c.type == "TensorListPushBack"][0] + output_idx = body_graph.outputs.index(push_back_node.outputs[0]) + output = while_op.outputs[output_idx] + + _, val = list_ops.tensor_list_pop_back(output, + element_dtype=dtypes.float32) + self.assertEqual(val.shape, tensor_shape.TensorShape(shape)) + + +def ScalarShape(): + return ops.convert_to_tensor([], dtype=dtypes.int32) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py new file mode 100644 index 0000000000..801217fe66 --- /dev/null +++ b/tensorflow/python/ops/while_v2.py @@ -0,0 +1,573 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""while_v2 and gradient. + +This is a version of while_loop that emits a single While op, as well as the +gradient function for While ops produced by while_loop. This will eventually +replace the current tf.while_loop implementation once it reaches feature and +performance parity. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.eager import function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl as cond_v2 +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import list_ops +from tensorflow.python.util import nest + +# pylint: disable=protected-access + +# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows +# control dependencies on external nodes with at least 1 output. +# Another idea is to create const nodes outside the loop and add control edges +# to them and then pass those in as data inputs. This should probably be +# handled in the CapturingGraph itself. + + +def while_loop(cond, body, loop_vars, name=None): + """Like tf.while_loop, except emits a single While op.""" + if not name: + name = "while" + + with ops.name_scope(name) as scope: + with ops.name_scope(None): + cond_name = _get_unique_name(("%scond" % scope).replace("/", "_")) + body_name = _get_unique_name(("%sbody" % scope).replace("/", "_")) + + flattened_loop_vars = nest.flatten(loop_vars) + num_outputs = len(flattened_loop_vars) + + # Add loop counter needed for computing gradients. + flattened_loop_vars = [constant_op.constant(0., name="loop_counter") + ] + flattened_loop_vars + + # Build a `cond` wrapper that can handle the extra counter loop_var. + def wrapped_cond(unused_loop_counter, *loop_vars): + return cond(*loop_vars) + + cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond, + flattened_loop_vars, {}) + + # Add external_captures of cond to the list of loop vars. + # Note that external tensors will be treated as loop invariants, i.e., + # the value of that tensor in each iteration is the same as it was at the + # beginning of the loop execution. + flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures + + def wrapped_body(loop_counter, *args): + """Loop body augmented with counter update. + + Args: + loop_counter: Loop counter which needs to be incremented in the body. + *args: List of args + args[:num_outputs] - Args for the original loop body. + args[num_outputs:] - External captures of cond. These get passed + through as is. + + Returns: + A list of tensors the same length as args. 
+ """ + outputs = body(*args[:num_outputs]) + if not isinstance(outputs, collections.Sequence): + outputs = [outputs] + + # Return the external_captures of cond_graph as is, i.e., treat them as + # loop invariants. + # TODO(srbs): Update lowering code to create _Enter nodes with + # is_constant=True for inputs that are directly passed to outputs. + return [loop_counter + 1] + list(outputs) + list(args[num_outputs:]) + + body_graph = function.func_graph_from_py_func(body_name, wrapped_body, + flattened_loop_vars, {}) + # Add external captures of body to the list of loop vars. + # Note that external tensors will be treated as loop invariants, i.e., + # the value of that tensor in each iteration is the same as it was at the + # beginning of the loop execution. + flattened_loop_vars = flattened_loop_vars + body_graph.external_captures + # TODO(srbs): Update lowering code to create _Enter nodes with + # is_constant=True for inputs that are directly passed to outputs. + body_graph.outputs.extend(body_graph.internal_captures) + + # Capture `external_captures` of `body_graph` in `cond_graph` so that it + # expects to receive those as arguments. + # TODO(srbs): Dedup tensors that are captured in both the cond and body. + # This logic already exists in cond_v2. + with cond_graph.as_default(): + for external_capture in body_graph.external_captures: + cond_graph.capture(external_capture) + + # Export all tensors in the loop body that may be needed for gradient + # computation. We do this by accumulating the intermediate values in + # TensorLists. + intermediate_tensors = _get_intermediates(body_graph) + + for intermediate_tensor in intermediate_tensors: + # TODO(srbs): Cache and re-use empty tensor lists. + tensor_list = list_ops.empty_tensor_list( + element_dtype=intermediate_tensor.dtype, + element_shape=_get_tensor_convertible_shape( + intermediate_tensor.shape)) + flattened_loop_vars.append(tensor_list) + with cond_graph.as_default(): + # Add a placeholder to cond_graph's inputs corresponding to the + # tensor_list. + cond_graph.capture(tensor_list) + with body_graph.as_default(): + # Push the intermediate tensor to the tensor list. This captures the + # `tensor_list` as well. + appended_tensor_list = list_ops.tensor_list_push_back( + tensor_list, + intermediate_tensor) + # Add this modified tensor list to the list of outputs. + body_graph.outputs.append(appended_tensor_list) + + outputs = gen_functional_ops._while( + flattened_loop_vars, + cond_v2._create_new_tf_function(cond_graph), + cond_v2._create_new_tf_function(body_graph), + name=scope) + + _maybe_set_lowering_attr(outputs[0].op) + + # First var is loop counter. + if num_outputs == 1: + return outputs[1] + else: + return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs]) + + +@ops.RegisterGradient("While") +def _WhileGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of a While op produced by while_loop.""" + body_graph = _get_body_graph(op) + + # Replace None gradients with zeros. This is needed because `grads` could have + # None incoming gradients for the TensorLists. If we pass None's through, the + # custom gradient of TensorListPopBack will create an EmptyTensorList inside + # the FuncGraph which is undesirable. + # TODO(b/80444525): There might be an issue with treating no gradient as zero + # gradient in certain cases. Consider replacing None gradients with Zeros + # for accumulators only. 
+ grads = [ + g if g is not None else array_ops.zeros_like(output) + for g, output in zip(grads, op.outputs) + ] + + body_grad_graph, args = _create_grad_func( + body_graph, grads, + _get_unique_name("%s_grad" % body_graph.name), op) + + intermediate_tensors = _get_intermediates(body_grad_graph) + + for intermediate_tensor in intermediate_tensors: + tensor_list = list_ops.empty_tensor_list( + element_dtype=intermediate_tensor.dtype, + element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape)) + with body_grad_graph.as_default(): + tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True) + # Push the intermediate tensor to the tensor list. + appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph, + intermediate_tensor) + # Add this modified tensor list to the list of outputs. + body_grad_graph.outputs.append(appended_tensor_list) + + def grad_cond(counter, max_iters, *unused_args): + return counter < max_iters + + loop_vars = args + body_grad_graph.external_captures + cond_grad_graph = function.func_graph_from_py_func( + _get_unique_name("%s_grad_cond" % op.name), + grad_cond, loop_vars, {}) + + assert len(loop_vars) == len(body_grad_graph.inputs) + assert len(loop_vars) == len(body_grad_graph.outputs) + assert len(loop_vars) == len(cond_grad_graph.inputs) + + outputs = gen_functional_ops._while( + loop_vars, + cond_v2._create_new_tf_function(cond_grad_graph), + cond_v2._create_new_tf_function(body_grad_graph), + name=_get_unique_name("%s_grad" % op.name)) + + _maybe_set_lowering_attr(outputs[0].op) + + # outputs[0] is the loop counter. + # outputs[1] is the total number of loop iterations. + return outputs[2:2 + len(op.inputs)] + + +# TODO(srbs): Pull this into common utils for cond_v2 and while_v2. +def _get_body_graph(while_op): + """Returns `FuncGraph` for the while body. + + Args: + while_op: The While Operation. + + Returns: + `FuncGraph` for the while body. + """ + extra_inputs = list(while_op.inputs) + input_shapes = [t.shape for t in extra_inputs] + func_name = while_op.get_attr("body").name + fdef = while_op.graph._get_function(func_name).definition + func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + func_graph._while = while_op + return func_graph + + +def _create_grad_func(func_graph, grads, name, while_op): + """Builds and returns the gradient FuncGraph of `func_graph` and its args. + + The returned grad_func_graph must be called with the returned + args + grad_func_graph.captures. + + Args: + func_graph: FuncGraph for the forward body function. + grads: The incoming grads for `func_graph`'s outputs. + name: Name of the returned gradient function. + while_op: The forward While op. + + Returns: + 2-tuple of (grad_func_graph, args). + """ + assert len(func_graph.outputs) == len(grads) + + loop_counter = constant_op.constant(0.) + # TODO(srbs): For nested while loops will need to lookup this value from + # the accumulator of the enclosing while loop. For now use as is assuming + # there is no nesting. + num_iters_t = while_op.outputs[0] + + args = [loop_counter, num_iters_t] + grads + + # Note: The returned function does not have `args` in the list of + # `external_captures`. + grad_func_graph = function.func_graph_from_py_func( + name, + lambda *args: _grad_fn(func_graph, args), + args, {}, + func_graph=_WhileBodyGradFuncGraph(name, func_graph)) + + # Add the popped accumulators to the list of outputs. 
+ for internal_capture in grad_func_graph.internal_captures: + grad_func_graph.outputs.append( + grad_func_graph.popped_tensor_lists[internal_capture]) + + return grad_func_graph, args + + +def _grad_fn(func_graph, args): + """Computes the gradient of `func_graph` in the current graph. + + This function builds the gradient graph of the corresponding forward-pass + `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs. + + Args: + func_graph: function.FuncGraph. The corresponding forward-pass function. + args: The input arguments. args[0] - Loop counter args[1] - Total number of + iterations. + args[2:] - Incoming gradients for `func_graph.outputs`. + + Returns: + The output gradient Tensors. + """ + xs = func_graph.inputs + ys = func_graph.outputs + grad_ys = args[2:] + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _resolve_grad_inputs. + # TODO(srbs): Mark GradientsHelper as public? + grad_outs = gradients_impl._GradientsHelper( + ys, xs, grad_ys=grad_ys, src_graph=func_graph) + + assert all([g is not None for g in grad_outs]) + counter = args[0] + total_iters = args[1] + return [counter + 1, total_iters] + grad_outs + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that should be accumulated.""" + # We currently accumulate output tensors of most ops in the function and rely + # on the pruning pass to get rid of the unused accumulators at runtime. + # However, this can bloat the GraphDef and make debugging harder so we perform + # some optimizations. + # + # Optimization we currently perform: + # 1. We do not accumulate tensors which already have an accumulator + # in the loop body. + # 2. We do not accumulate outputs of Identity nodes. When building the + # FuncGraph, we add an Identity node for each output (see + # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs + # of all these nodes bloats the GraphDef quite a bit so we remove those. + # Since the gradient of an Identity node does not rely on its forward op's + # input this is safe to do. + # + # Other possible optimizations: + # 1. Only accumulate tensors that will be required by the backward pass. + # This will require running the gradient pass and hence would increase the + # graph building time for the forward pass. + # 2. Do not accumulate Const nodes created inside the loop body. + # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants. + # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer + # since it requires knowledge of the while loop semantics. If so, consider + # doing those here. + intermediates = [] + + for op in func_graph.get_operations(): + if op.type == "Identity": + continue + for o in op.outputs: + if (o != func_graph.inputs[0] and # Loop counter. + _get_accumulator(o) is None): # Has existing accumulator. + intermediates.append(o) + return intermediates + + +def _get_accumulator(tensor): + r"""Returns TensorList if any containing accumulated values of tensor. + + We try to find a pattern of the form: + + input_tl tensor + \ / + (TensorListPushBack) + | + output_tl + + which satisfies the following conditions: + + 1. input_tl must be in tensor.graph.inputs. + 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs. + 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_t). 
+ + output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is + returned if such a pattern is found else None is returned. + + Args: + tensor: The Tensor to be accumulated. + + Returns: + A variant tensor in the same graph as `tensor` or None if no accumulator is + found. + """ + assert isinstance(tensor.graph, function.FuncGraph) + + def get_func_graph_output(t): + """Returns t or Identity(t) whichever exists in graph outputs else None.""" + if t in tensor.graph.outputs: + return t + # tf.defun adds an Identity for each output, check whether that is the case. + identity_op = t.consumers()[0] + if (identity_op.type == "Identity" and + identity_op.outputs[0] in tensor.graph.outputs): + return identity_op.outputs[0] + return None + + for consumer in tensor.consumers(): + # Find the consumer that is a TensorListPushBack node whose TensorList input + # is in the list of function inputs. + if (consumer.type != "TensorListPushBack" or + consumer.inputs[0] not in tensor.graph.inputs): + continue + + output = get_func_graph_output(consumer.outputs[0]) + if output is None: + # The TensorList output of `consumer` is not in the list of function + # outputs. + continue + + accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0]) + accum_output_idx = tensor.graph.outputs.index(output) + if accum_input_idx == accum_output_idx: + return output + return None + + +# TODO(srbs): Add to common utils for cond_v2 and while_v2. +def _get_unique_name(name): + """Returns a name that is unique in the root graph of `func_graph`. + + Args: + name: String to uniquify. + + Returns: + A string. + """ + with ops.init_scope(): + return ops.get_default_graph().unique_name(name) + + +class _WhileBodyGradFuncGraph(function.FuncGraph): + """FuncGraph for the gradient function of the body of a While op. + + Contains the logic for capturing the tensors from the body of the forward + While op which is as follows: + 1. Find the accumulator for that tensor. + 2. Capture the forward While op output tensor corresponding to the + accumulator in this FuncGraph. + 3. Pop a value from the captured placeholder and use it as the captured value + for the forward pass tensor. + + This only allows capturing tensors in the forward graph. A ValueError is + raised if an attempt is made to capture a tensor not in the forward graph. + To manually capture capture a tensor that is not in the forward graph, call + `capture` with `whitelisted=True`. + + Note: The `captures` dict does not contain the forward tensor since it is not + directly captured. It contains the accumulator corresponding to this forward + tensor. + + Attributes: + popped_tensor_lists: Dict from the captured accumulator placeholder to the + TensorList obtained after popping the intermediate tensor from it. The + values of this dict need to be added to the list of outputs. + """ + + def __init__(self, name, forward_graph): + super(_WhileBodyGradFuncGraph, self).__init__(name) + self.popped_tensor_lists = {} + # FuncGraph for the body of the forward While op. + self._forward_graph = forward_graph + # Dict from forward intermediate tensor to the corresponding "popped" tensor + # in this graph. + self._indirect_captures = {} + # Dict from forward graph tensor to the While op output corresponding to its + # accumulator. + self._tensor_to_accumulator = {} + + def capture(self, tensor, name=None, whitelisted=False): + """Selectively captures external tensors. + + If `whitelisted` is False only allows capturing tensors in the + `_forward_graph`. 
+ + Args: + tensor: Tensor. May be from this FuncGraph or a different graph. + name: Optional name if a placeholder is created. + whitelisted: If False (default), only allows capturing tensors from the + forward graph. + + Returns: + The placeholder in this graph for the tensor. + + Raises: + ValueError: If attempting to capture an external tensor not in the forward + graph with `whitelisted` set to False. + """ + if (not whitelisted and tensor.graph is not self and + tensor.graph != self._forward_graph): + raise ValueError("Attempting to capture tensor", str(tensor), + " which is not in the forward graph but in ", + _graph_name(tensor.graph), ".") + return super(_WhileBodyGradFuncGraph, self).capture(tensor, name) + + def _capture_helper(self, tensor, name): + if tensor.graph is not self._forward_graph: + return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name) + + captured_tensor = self._indirect_captures.get(tensor) + if captured_tensor is not None: + # For GradientTape housekeeping. + assert self._tensor_to_accumulator[tensor] in self.captures + super(_WhileBodyGradFuncGraph, self)._capture_helper( + self._tensor_to_accumulator[tensor], name) + return captured_tensor + + assert tensor not in self._tensor_to_accumulator + + accumulator = None + + # Find the TensorList that was used to accumulate the tensors of this + # intermediate tensor. + accumulator = _get_accumulator(tensor) + if accumulator is None: + raise ValueError("Reference to un-accumulated intermediate tensor: ", + tensor.name) + assert accumulator.graph == self._forward_graph + # Get the While op output corresponding to the accumulator. + accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs + .index(accumulator)] + + assert accumulator.graph == self._forward_graph.outer_graph + self._tensor_to_accumulator[tensor] = accumulator + + # Capture the `accumulator`. + accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper( + accumulator, name) + new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back( + accumulator_ph, element_dtype=tensor.dtype) + self._indirect_captures[tensor] = captured_tensor + self.popped_tensor_lists[accumulator_ph] = new_tensor_list + return captured_tensor + + +# TODO(srbs): Move to common utils for cond_v2 and while_v2. +def _maybe_set_lowering_attr(op): + """Sets the flag to enable lowering on the `While` op if necessary. + + Lowering allows while_v2 to avoid some of the limitations of Functions, + allowing users to specify devices & colocation inside of while_v2 + branches, and enabling non-strict evaluation & partial pruning of while_v2 + branches. This brings while_v2 closer to feature parity with + tf.while_loop. + + However, we do not lower `While` in the XLA context because it is easier + for XLA to apply its own optimizations when dealing with un-lowered + `While` operators than with low-level control flow primitives. + + Args: + op: The While op. + """ + if not control_flow_util.IsInXLAContext(op): + # pylint: disable=protected-access + op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + + +def _get_tensor_convertible_shape(shape): + assert isinstance(shape, tensor_shape.TensorShape) + if shape.is_fully_defined(): + return shape + if not shape: # Unknown shape. + return -1 + # Partially defined shape. 
+ shape_list = shape.as_list() + shape_list = [s if s is not None else -1 for s in shape_list] + return ops.convert_to_tensor(shape_list) + + +def _graph_name(graph): + if isinstance(graph, function.FuncGraph): + return graph.name + return "Base" + + +# pylint: enable=protected-access diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 31a3712de8..f86cb03995 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -114,6 +114,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/tools:tools_pip", "//tensorflow/python/tools/api/generator:create_python_api", "//tensorflow/python:test_ops", + "//tensorflow/python:while_v2", "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ] |
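
For reference, the accumulation scheme that _get_intermediates and _get_accumulator set up, and that _WhileBodyGradFuncGraph._capture_helper later consumes, can be written out by hand with the same TensorList ops used in this change. The sketch below is illustrative only: the variable names are hypothetical, the loop is the v1 while_loop imported in the tests, and only the list_ops calls mirror the diff.

    # The forward body pushes each intermediate onto a TensorList accumulator;
    # the gradient body pops the values back off in reverse order.
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import list_ops
    from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1

    x = constant_op.constant(2.)
    acc = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))  # scalar elements

    def Body(v, tl):
      # Accumulate the intermediate value the backward pass will need.
      return v * v, list_ops.tensor_list_push_back(tl, v)

    out, acc_out = while_loop_v1(lambda v, tl: v < 8., Body, [x, acc])

    # The gradient function pops the most recent intermediate off the list,
    # much as _WhileBodyGradFuncGraph does for each captured forward tensor.
    acc_rest, last_v = list_ops.tensor_list_pop_back(
        acc_out, element_dtype=dtypes.float32)
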