diff options
author | 2017-09-19 19:08:19 -0700 | |
---|---|---|
committer | 2017-09-19 19:12:43 -0700 | |
commit | 1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch) | |
tree | 9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla/xla_compiler_test.cc | |
parent | 5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff) |
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple.
PiperOrigin-RevId: 169338418
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aa8df80d34..f516dd867a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -349,5 +350,145 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests a computation that receives a TensorArray resource as input and +// updates it. +TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index, + grad2.flow_out); + auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad2"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + ASSERT_EQ(1, result.resource_updates.size()); + const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; + EXPECT_EQ(0, update.input_index); + EXPECT_EQ(DT_INT32, update.type); + EXPECT_EQ((std::set<string>{"grad1", "grad2"}), + update.tensor_array_gradients_accessed); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> input_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> input_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> input = + xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*input).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_->Execute(*result.computation, {param0_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); + std::unique_ptr<xla::Literal> output_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> output_grad1 = + xla::Literal::CreateR1<int32>({0, 1}); + std::unique_ptr<xla::Literal> output_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( + {output_base.get(), output_grad1.get(), output_grad2.get()}); + std::unique_ptr<xla::Literal> expected_literal = + xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(0, result.resource_updates.size()); +} + +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(1, result.resource_updates.size()); +} + } // namespace } // namespace tensorflow |