diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-10 12:33:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 12:38:19 -0700 |
commit | dd6d7c5c586b541b9d4793b7578feadd0c2da8f6 (patch) | |
tree | c69ca553da1100b948bd81fc85784f2302b0adbf /tensorflow/compiler/tf2xla | |
parent | 656b3e9c847c187ff011982fe806f9f48853ed1a (diff) |
Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/graph_compiler.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/util.cc | 26 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/literal_util_test.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla_test.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 198 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 7 |
7 files changed, 115 insertions, 155 deletions
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index bc2e640559..82e9eef005 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11..3d81ae9eb8 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector<xla::XlaOp> args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c267848524..804671fbc7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value)); + literal = xla::LiteralUtil::CreateR0<uint8>(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value)); + literal = xla::LiteralUtil::CreateR0<uint32>(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value)); + literal = xla::LiteralUtil::CreateR0<uint64>(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value)); + literal = xla::LiteralUtil::CreateR0<int8>(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value)); + literal = xla::LiteralUtil::CreateR0<int32>(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value)); + literal = xla::LiteralUtil::CreateR0<int64>(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0<float>(value)); + literal = xla::LiteralUtil::CreateR0<float>(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0<double>(value)); + literal = xla::LiteralUtil::CreateR0<double>(value); break; case xla::C64: - literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value)); + literal = xla::LiteralUtil::CreateR0<complex64>(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value))); + literal = + xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>( - static_cast<xla::half>(value))); + literal = + xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 7dc16b5a46..ed452bceeb 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. { std::vector<int64> int64_values = {1, 2, 3}; - std::unique_ptr<xla::Literal> int64_values_literal = + xla::Literal int64_values_literal = xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); test::ExpectTensorEqual<int64>(host_tensor, test::AsTensor<int64>(int64_values)); } @@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) { // Repeat tests with int32. Tensor host_tensor; std::vector<int32> int32_values = {10, 11}; - std::unique_ptr<xla::Literal> int32_values_literal = + xla::Literal int32_values_literal = xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values)); EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); + LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok()); test::ExpectTensorEqual<int32>(host_tensor, test::AsTensor<int32>(int32_values)); EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) + LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor) .ok()); std::vector<qint32> qint32_values = {10, 11}; test::ExpectTensorEqual<qint32>(host_tensor, test::AsTensor<qint32>(qint32_values)); EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) + LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor) .error_message()); } } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98..ab26d939cc 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. auto x_literal = xla::LiteralUtil::CreateR0<int32>(10); auto y_literal = xla::LiteralUtil::CreateR0<int32>(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr<xla::GlobalData> x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 40ce9fb41c..70efa7781d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({4, 143}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({-7, -42}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR0<int32>(7); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({-7, -42}); - std::unique_ptr<xla::Literal> expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> input_base = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> input_grad2 = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> output_read = - xla::LiteralUtil::CreateR0<int32>(42); - std::unique_ptr<xla::Literal> output_base = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> output_grad1 = - xla::LiteralUtil::CreateR1<int32>({0, 1}); - std::unique_ptr<xla::Literal> output_grad2 = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<int32>({7, 42}); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({5, 144}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({4, 143}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a simple graph that reads and writes a variable. @@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR1<int32>({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ReturnResourceHandle) { @@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}}); - std::unique_ptr<xla::Literal> param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr<xla::Literal> param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3}); - std::unique_ptr<xla::Literal> param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); - std::unique_ptr<xla::Literal> expected1 = - xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); - std::unique_ptr<xla::Literal> expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index d1534e9a15..d10a504da0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr<std::unique_ptr<xla::Literal>> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } |