diff options
author | 2018-09-10 12:33:49 -0700 | |
---|---|---|
committer | 2018-09-10 12:38:19 -0700 | |
commit | dd6d7c5c586b541b9d4793b7578feadd0c2da8f6 (patch) | |
tree | c69ca553da1100b948bd81fc85784f2302b0adbf /tensorflow/compiler/tf2xla/xla_compiler_test.cc | |
parent | 656b3e9c847c187ff011982fe806f9f48853ed1a (diff) |
Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler_test.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 198 |
1 files changed, 81 insertions, 117 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 40ce9fb41c..70efa7781d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({4, 143}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({-7, -42}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR0<int32>(7); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({-7, -42}); - std::unique_ptr<xla::Literal> expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> input_base = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> input_grad2 = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> output_read = - xla::LiteralUtil::CreateR0<int32>(42); - std::unique_ptr<xla::Literal> output_base = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> output_grad1 = - xla::LiteralUtil::CreateR1<int32>({0, 1}); - std::unique_ptr<xla::Literal> output_grad2 = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({5, 144}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({4, 143}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a simple graph that reads and writes a variable. @@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ReturnResourceHandle) { @@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}}); - std::unique_ptr<xla::Literal> param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3}); - std::unique_ptr<xla::Literal> param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. |