From 1f20a786d69c4b91a4015fe3f4df8c23bd345f40 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Sep 2017 19:08:19 -0700 Subject: [TF:XLA] Add support for reading and writing TensorArray gradients in a while loop. Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple. PiperOrigin-RevId: 169338418 --- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 141 ++++++++++++++++++++++++ 1 file changed, 141 insertions(+) (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc') diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aa8df80d34..f516dd867a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -349,5 +350,145 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests a computation that receives a TensorArray resource as input and +// updates it. 
+TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index, + grad2.flow_out); + auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad2"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + ASSERT_EQ(1, result.resource_updates.size()); + const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; + EXPECT_EQ(0, update.input_index); + EXPECT_EQ(DT_INT32, update.type); + EXPECT_EQ((std::set<string>{"grad1", "grad2"}), + update.tensor_array_gradients_accessed); + + // Tests that the generated computation works. 
+ std::unique_ptr<xla::Literal> input_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> input_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> input = + xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*input).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_->Execute(*result.computation, {param0_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); + std::unique_ptr<xla::Literal> output_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> output_grad1 = + xla::Literal::CreateR1<int32>({0, 1}); + std::unique_ptr<xla::Literal> output_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( + {output_base.get(), output_grad1.get(), output_grad2.get()}); + std::unique_ptr<xla::Literal> expected_literal = + xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests that a TensorArray gradient that is read but never written inside the +// computation is not returned as a computation output. +TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. 
+ std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(0, result.resource_updates.size()); +} + +// Tests that a TensorArray gradient first created inside the computation is +// returned as a computation output. +TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. 
+ XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(1, result.resource_updates.size()); +} + } // namespace } // namespace tensorflow -- cgit v1.2.3