diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 129 |
1 files changed, 91 insertions, 38 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 613230452b..2fb93be01d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -222,12 +222,64 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({4, 143}); + xla::LiteralUtil::CreateR1<int32>({4, 143}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests compilation of a graph where the _Retval node is not necessarily last +// amongst the graph nodes in construction order, and always_return_tuple is +// false. Regression test for bug where the wrong value was returned. +TEST_F(XlaCompilerTest, OutOfOrderGraph) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // The _Retval node is not last in construction order. + auto d = ops::_Retval(scope.WithOpName("D"), a, 0); + auto c = ops::Add(scope.WithOpName("C"), a, b); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> param0_literal = + xla::LiteralUtil::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> param1_literal = + xla::LiteralUtil::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr<xla::GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. @@ -306,7 +358,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -317,9 +369,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({-7, -42}); + xla::LiteralUtil::CreateR1<int32>({-7, -42}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE( xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -341,7 +393,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -351,11 +403,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr<xla::Literal> actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7); + std::unique_ptr<xla::Literal> expected0 = + xla::LiteralUtil::CreateR0<int32>(7); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({-7, -42}); + xla::LiteralUtil::CreateR1<int32>({-7, -42}); std::unique_ptr<xla::Literal> expected = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } @@ -569,11 +622,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> input_base = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> input_grad2 = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::Literal> input = - xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*input).ConsumeValueOrDie(); @@ -583,17 +636,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { std::unique_ptr<xla::Literal> actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); + std::unique_ptr<xla::Literal> output_read = + xla::LiteralUtil::CreateR0<int32>(42); std::unique_ptr<xla::Literal> output_base = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> output_grad1 = - xla::Literal::CreateR1<int32>({0, 1}); + xla::LiteralUtil::CreateR1<int32>({0, 1}); std::unique_ptr<xla::Literal> output_grad2 = - xla::Literal::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( + xla::LiteralUtil::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple( {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -796,9 +850,9 @@ TEST_F(XlaCompilerTest, Variables) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -812,11 +866,11 @@ TEST_F(XlaCompilerTest, Variables) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({5, 144}); + xla::LiteralUtil::CreateR1<int32>({5, 144}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({4, 143}); + xla::LiteralUtil::CreateR1<int32>({4, 143}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -884,9 +938,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -900,11 +954,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}}); + xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -953,9 +1007,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -969,11 +1023,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({27, 67, 35, 402}); + xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -1021,8 +1075,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); } |