author    Peter Hawkins <phawkins@google.com>  2017-09-19 19:08:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-19 19:12:43 -0700
commit    1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch)
tree      9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla/xla_compiler_test.cc
parent    5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff)
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple.

PiperOrigin-RevId: 169338418
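For reference, a minimal sketch of the packed shape that the tests below construct: a TensorArray of two int32 elements with one declared gradient is represented as a tuple holding the base element array plus one element array per gradient. The helper name TensorArrayArgShape is hypothetical; the ShapeUtil calls mirror args[0].shape in the tests, and this is an illustration, not code from this change.

    #include "tensorflow/compiler/xla/shape_util.h"

    // Hypothetical helper: the XLA shape of a 2-element int32 TensorArray
    // argument with one declared gradient, i.e. a
    // (base array, gradient values) tuple as described above.
    xla::Shape TensorArrayArgShape() {
      // Each TensorArray (base or gradient) stores two int32 elements.
      xla::Shape elements = xla::ShapeUtil::MakeShape(xla::S32, {2});
      // Pack the base array and the single gradient's values into a tuple.
      return xla::ShapeUtil::MakeTupleShape({elements, elements});
    }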
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc | 141
1 file changed, 141 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index aa8df80d34..f516dd867a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -349,5 +350,145 @@ TEST_F(XlaCompilerTest, ResourceManager) {
resource->Unref();
}
+// Tests a computation that receives a TensorArray resource as input and
+// updates it.
+TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
+ auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
+ auto index = ops::Const<int32>(scope, 1);
+ auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
+ grad2.flow_out);
+ auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad2"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ ASSERT_EQ(1, result.resource_updates.size());
+ const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
+ EXPECT_EQ(0, update.input_index);
+ EXPECT_EQ(DT_INT32, update.type);
+ EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
+ update.tensor_array_gradients_accessed);
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> input_base =
+ xla::Literal::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> input_grad2 =
+ xla::Literal::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> input =
+ xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_->Execute(*result.computation, {param0_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
+ std::unique_ptr<xla::Literal> output_base =
+ xla::Literal::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> output_grad1 =
+ xla::Literal::CreateR1<int32>({0, 1});
+ std::unique_ptr<xla::Literal> output_grad2 =
+ xla::Literal::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
+ {output_base.get(), output_grad1.get(), output_grad2.get()});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
+ xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
+// Tests that a TensorArray gradient that is created but never written is not
+// returned as a computation output (resource update).
+TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
+ auto index = ops::Const<int32>(scope, 1);
+ auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad1"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ EXPECT_EQ(0, result.resource_updates.size());
+}
+
+// Tests that a TensorArray gradient first created inside the computation is
+// returned as a computation output (resource update).
+TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
+ auto index = ops::Const<int32>(scope, 1);
+ auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad1"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ EXPECT_EQ(1, result.resource_updates.size());
+}
+
} // namespace
} // namespace tensorflow