diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 6f76816a86..2fb93be01d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -228,6 +228,58 @@ TEST_F(XlaCompilerTest, Simple) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests compilation of a graph where the _Retval node is not necessarily last +// amongst the graph nodes in construction order, and always_return_tuple is +// false. Regression test for bug where the wrong value was returned. +TEST_F(XlaCompilerTest, OutOfOrderGraph) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // The _Retval node is not last in construction order. + auto d = ops::_Retval(scope.WithOpName("D"), a, 0); + auto c = ops::Add(scope.WithOpName("C"), a, b); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> param0_literal = + xla::LiteralUtil::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> param1_literal = + xla::LiteralUtil::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr<xla::GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. |