aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 12:33:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 12:38:19 -0700
commitdd6d7c5c586b541b9d4793b7578feadd0c2da8f6 (patch)
treec69ca553da1100b948bd81fc85784f2302b0adbf /tensorflow/compiler/tf2xla/xla_compiler_test.cc
parent656b3e9c847c187ff011982fe806f9f48853ed1a (diff)
Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc198
1 files changed, 81 insertions, 117 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 40ce9fb41c..70efa7781d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler doesn't reorder the parameters.
@@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
void RunAndCheckVariablesComputation(
xla::Client* client, const XlaCompiler::CompilationResult& result) {
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a simple graph that reads and writes a variable.
@@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.