aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 6f76816a86..2fb93be01d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -228,6 +228,58 @@ TEST_F(XlaCompilerTest, Simple) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests compilation of a graph where the _Retval node is not necessarily last
+// amongst the graph nodes in construction order, and always_return_tuple is
+// false. Regression test for bug where the wrong value was returned.
+TEST_F(XlaCompilerTest, OutOfOrderGraph) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+ // The _Retval node is not last in construction order.
+ auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
+ auto c = ops::Add(scope.WithOpName("C"), a, b);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kParameter;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.always_return_tuple = false;
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+ args, &result));
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+}
+
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
// Builds a graph that adds reshapes a tensor, but with the shape not
// statically known.