aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 12:33:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 12:38:19 -0700
commitdd6d7c5c586b541b9d4793b7578feadd0c2da8f6 (patch)
treec69ca553da1100b948bd81fc85784f2302b0adbf /tensorflow/compiler/tf2xla
parent656b3e9c847c187ff011982fe806f9f48853ed1a (diff)
Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc6
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc26
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc23
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc198
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc7
7 files changed, 115 insertions, 155 deletions
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index bc2e640559..82e9eef005 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+ LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 22a45b2a11..3d81ae9eb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index c267848524..804671fbc7 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+ literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U32:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+ literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
case xla::U64:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+ literal = xla::LiteralUtil::CreateR0<uint64>(value);
break;
case xla::S8:
- literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+ literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S32:
- literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+ literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
case xla::S64:
- literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+ literal = xla::LiteralUtil::CreateR0<int64>(value);
break;
case xla::F32:
- literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+ literal = xla::LiteralUtil::CreateR0<float>(value);
break;
case xla::F64:
- literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+ literal = xla::LiteralUtil::CreateR0<double>(value);
break;
case xla::C64:
- literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+ literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
- literal = std::move(
- *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
break;
case xla::F16:
- literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
- static_cast<xla::half>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index 7dc16b5a46..ed452bceeb 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
{
std::vector<int64> int64_values = {1, 2, 3};
- std::unique_ptr<xla::Literal> int64_values_literal =
+ xla::Literal int64_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
.error_message());
- EXPECT_EQ(
- "Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
EXPECT_TRUE(
- LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
- .ok());
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
test::ExpectTensorEqual<int64>(host_tensor,
test::AsTensor<int64>(int64_values));
}
@@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// Repeat tests with int32.
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
- std::unique_ptr<xla::Literal> int32_values_literal =
+ xla::Literal int32_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
- .ok());
+ LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
test::ExpectTensorEqual<int32>(host_tensor,
test::AsTensor<int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
.ok());
std::vector<qint32> qint32_values = {10, 11};
test::ExpectTensorEqual<qint32>(host_tensor,
test::AsTensor<qint32>(qint32_values));
EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
.error_message());
}
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 56f7045a98..ab26d939cc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
- auto x_global_or = client->TransferToServer(*x_literal);
- auto y_global_or = client->TransferToServer(*y_literal);
+ auto x_global_or = client->TransferToServer(x_literal);
+ auto y_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
@@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
- std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+ xla::Literal result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 40ce9fb41c..70efa7781d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler doesn't reorder the parameters.
@@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
void RunAndCheckVariablesComputation(
xla::Client* client, const XlaCompiler::CompilationResult& result) {
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a simple graph that reads and writes a variable.
@@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d1534e9a15..d10a504da0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}