aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc129
1 files changed, 91 insertions, 38 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 613230452b..2fb93be01d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -222,12 +222,64 @@ TEST_F(XlaCompilerTest, Simple) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests compilation of a graph where the _Retval node is not necessarily last
+// amongst the graph nodes in construction order, and always_return_tuple is
+// false. Regression test for bug where the wrong value was returned.
+TEST_F(XlaCompilerTest, OutOfOrderGraph) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+ // The _Retval node is not last in construction order.
+ auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
+ auto c = ops::Add(scope.WithOpName("C"), a, b);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kParameter;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.always_return_tuple = false;
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+ args, &result));
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+}
+
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
// Builds a graph that adds reshapes a tensor, but with the shape not
// statically known.
@@ -306,7 +358,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -317,9 +369,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -341,7 +393,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -351,11 +403,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
+ std::unique_ptr<xla::Literal> expected0 =
+ xla::LiteralUtil::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
}
}
@@ -569,11 +622,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> input_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> input_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::Literal> input =
- xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*input).ConsumeValueOrDie();
@@ -583,17 +636,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
+ std::unique_ptr<xla::Literal> output_read =
+ xla::LiteralUtil::CreateR0<int32>(42);
std::unique_ptr<xla::Literal> output_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> output_grad1 =
- xla::Literal::CreateR1<int32>({0, 1});
+ xla::LiteralUtil::CreateR1<int32>({0, 1});
std::unique_ptr<xla::Literal> output_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
+ xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -796,9 +850,9 @@ TEST_F(XlaCompilerTest, Variables) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -812,11 +866,11 @@ TEST_F(XlaCompilerTest, Variables) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({5, 144});
+ xla::LiteralUtil::CreateR1<int32>({5, 144});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -884,9 +938,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
+ xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -900,11 +954,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
+ xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -953,9 +1007,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({4, 55, 1, -3});
+ xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -969,11 +1023,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({27, 67, 35, 402});
+ xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -1021,8 +1075,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
}