aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 12:33:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 12:38:19 -0700
commitdd6d7c5c586b541b9d4793b7578feadd0c2da8f6 (patch)
treec69ca553da1100b948bd81fc85784f2302b0adbf
parent656b3e9c847c187ff011982fe806f9f48853ed1a (diff)
Global de-std::unique_ptr cleanup for xla::Literal.
PiperOrigin-RevId: 212313258
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc6
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc26
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc23
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc198
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc7
-rw-r--r--tensorflow/compiler/xla/client/client.cc12
-rw-r--r--tensorflow/compiler/xla/client/client.h10
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc4
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc20
-rw-r--r--tensorflow/compiler/xla/client/local_client.h10
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc2
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h38
-rw-r--r--tensorflow/compiler/xla/literal.cc133
-rw-r--r--tensorflow/compiler/xla/literal.h53
-rw-r--r--tensorflow/compiler/xla/literal_test.cc910
-rw-r--r--tensorflow/compiler/xla/literal_util.cc273
-rw-r--r--tensorflow/compiler/xla/literal_util.h228
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc10
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.h3
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc20
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h8
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i18
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.cc7
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.h2
-rw-r--r--tensorflow/compiler/xla/reference_util.cc28
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc50
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc19
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc12
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc66
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc16
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc32
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc237
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h57
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc484
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h195
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc54
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.h25
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc6
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h14
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc15
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/service.cc42
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc12
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h8
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc19
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc256
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc128
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc26
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc89
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc53
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc71
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h101
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc26
-rw-r--r--tensorflow/compiler/xla/tests/concat_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc64
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc25
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc40
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc115
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc24
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc60
-rw-r--r--tensorflow/compiler/xla/tests/cross_replica_sum_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc41
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc69
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc117
-rw-r--r--tensorflow/compiler/xla/tests/execution_profile_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc130
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc161
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc23
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h12
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h30
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc43
-rw-r--r--tensorflow/compiler/xla/tests/local_client_allocation_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc253
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc2
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h3
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc150
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc22
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc87
-rw-r--r--tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc46
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc149
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc62
-rw-r--r--tensorflow/compiler/xla/tests/reduce_hlo_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_precision_test.cc37
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc123
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc184
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc308
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc42
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc51
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc38
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc172
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc74
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h12
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc204
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc152
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc66
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.cc11
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.h4
-rw-r--r--tensorflow/compiler/xla/text_literal_reader_test.cc17
-rw-r--r--tensorflow/compiler/xla/text_literal_writer_test.cc2
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc17
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_state_ops.h10
-rw-r--r--tensorflow/compiler/xrt/tests/raw_api_test.cc36
-rw-r--r--tensorflow/compiler/xrt/xrt_state.cc2
-rw-r--r--tensorflow/compiler/xrt/xrt_state.h2
147 files changed, 3797 insertions, 4195 deletions
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index bc2e640559..82e9eef005 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+ LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 22a45b2a11..3d81ae9eb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index c267848524..804671fbc7 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+ literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U32:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+ literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
case xla::U64:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+ literal = xla::LiteralUtil::CreateR0<uint64>(value);
break;
case xla::S8:
- literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+ literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S32:
- literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+ literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
case xla::S64:
- literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+ literal = xla::LiteralUtil::CreateR0<int64>(value);
break;
case xla::F32:
- literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+ literal = xla::LiteralUtil::CreateR0<float>(value);
break;
case xla::F64:
- literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+ literal = xla::LiteralUtil::CreateR0<double>(value);
break;
case xla::C64:
- literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+ literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
- literal = std::move(
- *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
break;
case xla::F16:
- literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
- static_cast<xla::half>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index 7dc16b5a46..ed452bceeb 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
{
std::vector<int64> int64_values = {1, 2, 3};
- std::unique_ptr<xla::Literal> int64_values_literal =
+ xla::Literal int64_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
.error_message());
- EXPECT_EQ(
- "Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
EXPECT_TRUE(
- LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
- .ok());
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
test::ExpectTensorEqual<int64>(host_tensor,
test::AsTensor<int64>(int64_values));
}
@@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// Repeat tests with int32.
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
- std::unique_ptr<xla::Literal> int32_values_literal =
+ xla::Literal int32_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
- .ok());
+ LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
test::ExpectTensorEqual<int32>(host_tensor,
test::AsTensor<int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
.ok());
std::vector<qint32> qint32_values = {10, 11};
test::ExpectTensorEqual<qint32>(host_tensor,
test::AsTensor<qint32>(qint32_values));
EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
.error_message());
}
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 56f7045a98..ab26d939cc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
- auto x_global_or = client->TransferToServer(*x_literal);
- auto y_global_or = client->TransferToServer(*y_literal);
+ auto x_global_or = client->TransferToServer(x_literal);
+ auto y_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
@@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
- std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+ xla::Literal result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 40ce9fb41c..70efa7781d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler doesn't reorder the parameters.
@@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
void RunAndCheckVariablesComputation(
xla::Client* client, const XlaCompiler::CompilationResult& result) {
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a simple graph that reads and writes a variable.
@@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d1534e9a15..d10a504da0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 8818f81312..5dde5b432f 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
Client::~Client() = default;
-StatusOr<std::unique_ptr<Literal>> Client::Transfer(
- const GlobalData& data, const Shape* shape_with_layout) {
+StatusOr<Literal> Client::Transfer(const GlobalData& data,
+ const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
@@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
@@ -162,7 +162,7 @@ Status Client::ResetDevice() {
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+StatusOr<Literal> Client::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
@@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
- const XlaComputation& computation, const Layout* output_layout) const {
+StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+ const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 7960b07868..6f4d33c469 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -96,8 +96,8 @@ class Client {
//
// If shape_with_layout is not nullptr, it points to a shape whose layout will
// be the layout of the returned literal.
- StatusOr<std::unique_ptr<Literal>> Transfer(
- const GlobalData& data, const Shape* shape_with_layout = nullptr);
+ StatusOr<Literal> Transfer(const GlobalData& data,
+ const Shape* shape_with_layout = nullptr);
// Transfer the given literal to the server. This allocates memory on the
// device and copies the literal's contents over. Returns a global data handle
@@ -122,7 +122,7 @@ class Client {
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ StatusOr<Literal> TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
@@ -132,7 +132,7 @@ class Client {
// Executes the computation with the given arguments and transfers the result
// to the client as a literal. Parameters are defined the same as for
// Execute() and Transfer().
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
@@ -153,7 +153,7 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ StatusOr<Literal> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 6861521acc..25cc37edc4 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (DataSizeOfShape(shape) < (1LL << 20)) {
- StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+ StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
@@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
- return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+ return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 4402ba8762..f96b6c9c26 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*argument));
- *hlo_snapshot->add_arguments() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+ *hlo_snapshot->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*result));
- *hlo_snapshot->mutable_result() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+ *hlo_snapshot->mutable_result() = literal.ToProto();
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream,
backend_->BorrowStream(shaped_buffer.device_ordinal()));
@@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
return std::move(scoped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
shaped_buffer.device_ordinal()));
@@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
literal);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal) {
+StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, shape, literal.get()));
+ executor, shape, &literal));
return std::move(literal);
}
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 56c3a3da02..feb2f8ec9d 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -84,8 +84,7 @@ class LocalExecutable {
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
- StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
@@ -132,8 +131,7 @@ class LocalClient : public Client {
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
- StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
@@ -151,8 +149,8 @@ class LocalClient : public Client {
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal);
+ StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal);
// Returns the device ordinal that corresponds to the given replica number.
//
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 887b970661..4e1ff9e5c0 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -738,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 58e8f4e7fa..833eafcf85 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -2112,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2129,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2189,12 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2207,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2221,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2236,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2253,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3f7635bd40..f1f255efae 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = absl::make_unique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+ Literal literal(shape);
+ literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
@@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+ const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = absl::make_unique<Literal>(proto.shape());
+ Literal literal(proto.shape());
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = absl::make_unique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
+ Literal result(new_shape);
+ TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
- result->shape(),
+ result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
+ TF_CHECK_OK(result.CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
@@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
- char* dest_data = static_cast<char*>(result->untyped_data());
+ char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
- std::unique_ptr<Literal> output;
+ Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
- output = CloneToUnique();
+ output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape_do_not_use() =
+ *output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
- ShapeUtil::HumanString(output->shape()));
+ ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- absl::Span<const int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = absl::make_unique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ Literal new_literal(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
+Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
- auto result_literal = absl::make_unique<Literal>(result_shape);
+ Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
+ result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
+ result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
-std::unique_ptr<Literal> LiteralBase::Slice(
- absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
return result;
}
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = absl::make_unique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+ const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
+ Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
+ auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(
+ Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
- absl::Span<complex64> dest_data = result_literal->data<complex64>();
+ absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
+ return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@@ -1362,8 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
+ bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
dest_shape.element_type() == BF16) {
@@ -1381,11 +1370,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
+ elements.push_back(std::move(new_element));
}
- auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
- return std::move(converted);
+ return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index b928cb6374..fa5b5f7fab 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -223,25 +223,21 @@ class LiteralBase {
//
// TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
// the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
+ bool round_f32_to_bf16 = false) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
+ // Clones the underlying buffers into a new Literal.
Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -259,24 +255,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
+ Literal Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+ Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape, absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Broadcast(const Shape& result_shape,
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,7 +280,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
+ Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -293,15 +288,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const;
+ Literal Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
+ Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@@ -312,7 +307,7 @@ class LiteralBase {
// initialization, then reinitialization. Conside if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+ static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@@ -539,8 +534,8 @@ class LiteralBase {
private:
template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape, absl::Span<const int64> start_indices) const;
+ Literal SliceInternal(const Shape& result_shape,
+ absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@@ -687,8 +682,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
@@ -1137,15 +1131,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@@ -1157,7 +1150,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
+ literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 1a64594db8..ba7fd29a62 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test {
Layout layout_r3_dim0minor_;
Layout layout_r4_dim0major_;
Layout layout_r4_dim0minor_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+ Literal literal_r4_2x2x3x3_dim0major_;
+ Literal literal_r4_2x2x3x3_dim0minor_;
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
- EXPECT_EQ("true", true_lit->ToString());
+ EXPECT_EQ("true", true_lit.ToString());
auto false_lit = LiteralUtil::CreateR0<bool>(false);
- EXPECT_EQ("false", false_lit->ToString());
+ EXPECT_EQ("false", false_lit.ToString());
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
- EXPECT_EQ("42", u32_lit->ToString());
+ EXPECT_EQ("42", u32_lit.ToString());
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
- EXPECT_EQ("-999", s32_lit->ToString());
+ EXPECT_EQ("-999", s32_lit.ToString());
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
- EXPECT_EQ("3.14", f32_lit->ToString());
+ EXPECT_EQ("3.14", f32_lit.ToString());
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
- EXPECT_EQ("0.5", f16_lit->ToString());
+ EXPECT_EQ("0.5", f16_lit.ToString());
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
- EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString());
+ EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- EXPECT_EQ("0.5", bf16_lit->ToString());
+ EXPECT_EQ("0.5", bf16_lit.ToString());
// 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- EXPECT_EQ("9", bf16_lit_truncated2->ToString());
+ EXPECT_EQ("9", bf16_lit_truncated2.ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
- EXPECT_EQ("{101}", pred_vec->ToString());
+ EXPECT_EQ("{101}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
@@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) {
{ 3, 4 },
{ 5, 6 }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
@@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) {
{ { 5 },
{ 6 } }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -171,7 +171,7 @@ f32[2,2] {
{ 3, 4 }
}
))";
- EXPECT_EQ(expected, tuple->ToString());
+ EXPECT_EQ(expected, tuple.ToString());
}
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
// clang-format on
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
{ 3, 4 },
@@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) {
};
std::vector<int64> expected_values = {8, 9, 7, 10};
- EXPECT_EQ(literal->sparse_indices()->data(),
+ EXPECT_EQ(literal.sparse_indices()->data(),
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
+ EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
{2001, 2002},
}, /*projection_p=*/1, /*projection_z=*/2);
// clang-format on
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
@@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
- EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+ EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
ElementsAre(2, 2, 3, 3));
- string result = literal_r4_2x2x3x3_dim0major_->ToString();
+ string result = literal_r4_2x2x3x3_dim0major_.ToString();
const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
@@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
});
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
- literal->EachCellAsString(
+ literal.EachCellAsString(
[&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) {
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
- EXPECT_EQ(*f32_42, *f32_42);
- EXPECT_EQ(*f32_42, *f32_42_clone);
+ EXPECT_EQ(f32_42, f32_42);
+ EXPECT_EQ(f32_42, f32_42_clone);
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*f32_42, *f32_123);
+ EXPECT_NE(f32_42, f32_123);
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
- EXPECT_NE(*f32_42, *f64_42);
+ EXPECT_NE(f32_42, f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(*matrix, *matrix);
- EXPECT_EQ(*matrix, *matrix_clone);
- EXPECT_NE(*matrix, *matrix_different);
- EXPECT_NE(*matrix, *vector_literal);
- EXPECT_NE(*matrix, *scalar);
- EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(matrix, matrix);
+ EXPECT_EQ(matrix, matrix_clone);
+ EXPECT_NE(matrix, matrix_different);
+ EXPECT_NE(matrix, vector_literal);
+ EXPECT_NE(matrix, scalar);
+ EXPECT_NE(matrix, nil);
EXPECT_EQ(nil, nil);
}
@@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) {
auto token1 = LiteralUtil::CreateToken();
auto scalar = LiteralUtil::CreateR0<float>(1.0);
- EXPECT_EQ(*token0, *token1);
- EXPECT_NE(*token0, *scalar);
+ EXPECT_EQ(token0, token1);
+ EXPECT_NE(token0, scalar);
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
- *LiteralUtil::MakeTuple({token0.get()}));
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+ LiteralUtil::MakeTuple({&token0}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&token1, &scalar}));
+ EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&scalar, &token1}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
- colmajor->Set<float>({0, 0}, 1.0);
- colmajor->Set<float>({0, 1}, 2.0);
- colmajor->Set<float>({1, 0}, 3.0);
- colmajor->Set<float>({1, 1}, 4.0);
+ Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ colmajor.Set<float>({0, 0}, 1.0);
+ colmajor.Set<float>({0, 1}, 2.0);
+ colmajor.Set<float>({1, 0}, 3.0);
+ colmajor.Set<float>({1, 1}, 4.0);
- auto rowmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
- rowmajor->Set<float>({0, 0}, 1.0);
- rowmajor->Set<float>({0, 1}, 2.0);
- rowmajor->Set<float>({1, 0}, 3.0);
- rowmajor->Set<float>({1, 1}, 4.0);
+ Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ rowmajor.Set<float>({0, 0}, 1.0);
+ rowmajor.Set<float>({0, 1}, 2.0);
+ rowmajor.Set<float>({1, 0}, 3.0);
+ rowmajor.Set<float>({1, 1}, 4.0);
- EXPECT_EQ(*rowmajor, *colmajor);
+ EXPECT_EQ(rowmajor, colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
- auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_EQ(*tuple1, *tuple2);
+ auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+ EXPECT_EQ(tuple1, tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_NE(*tuple1, *reversed_tuple);
+ auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+ EXPECT_NE(tuple1, reversed_tuple);
// Tuple with different value.
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
- auto different_tuple =
- LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_NE(*tuple1, *different_tuple);
+ auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+ EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
@@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) {
// tuple, the other is a clone of the element in the original tuple.
auto vector_clone =
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- EXPECT_EQ(*vector, *vector_clone);
+ EXPECT_EQ(vector, vector_clone);
auto vector_reversed =
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
- EXPECT_NE(*vector, *vector_reversed);
+ EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+ auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
// Tuples should always return false for IsAll.
- EXPECT_FALSE(tuple->IsAll(0));
- EXPECT_FALSE(tuple->IsAll(1));
+ EXPECT_FALSE(tuple.IsAll(0));
+ EXPECT_FALSE(tuple.IsAll(1));
}
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto scalar = LiteralUtil::CreateR0<float>(0.0);
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
- auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_EQ(*tuple, *x);
+ auto x = Literal::CreateFromShape(tuple.shape());
+ EXPECT_EQ(tuple, x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
- EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
- EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
- ->IsAll(-1));
+ .IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
- ->IsAllFloat(.5));
+ .IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(
- LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
- EXPECT_TRUE(scalar_zero->IsZero({}));
- EXPECT_FALSE(scalar_one->IsZero({}));
+ EXPECT_TRUE(scalar_zero.IsZero({}));
+ EXPECT_FALSE(scalar_one.IsZero({}));
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
- EXPECT_FALSE(array->IsZero({0, 1}));
- EXPECT_TRUE(array->IsZero({0, 2}));
- EXPECT_TRUE(array->IsZero({1, 1}));
- EXPECT_FALSE(array->IsZero({1, 2}));
+ EXPECT_FALSE(array.IsZero({0, 1}));
+ EXPECT_TRUE(array.IsZero({0, 2}));
+ EXPECT_TRUE(array.IsZero({1, 1}));
+ EXPECT_FALSE(array.IsZero({1, 2}));
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
- EXPECT_TRUE(complex_zero->IsZero({}));
- EXPECT_FALSE(complex_nonzero->IsZero({}));
+ EXPECT_TRUE(complex_zero.IsZero({}));
+ EXPECT_FALSE(complex_nonzero.IsZero({}));
}
template <typename T>
@@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
- auto data01 = data->Relayout(layout01);
- EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_EQ(*data, *data01);
+ auto data01 = data.Relayout(layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+ EXPECT_EQ(data, data01);
- auto data10 = data->Relayout(layout10);
- EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_EQ(*data, *data10);
+ auto data10 = data.Relayout(layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+ EXPECT_EQ(data, data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Transpose(/*permutation=*/{});
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}});
// clang-format on
- auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+ auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
+ reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+ EXPECT_EQ(value, original.Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
}
@@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
- literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+ literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
- literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+ literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->element_count(), 6);
- EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_EQ(mat_dim0minor.element_count(), 6);
+ EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
- auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+ EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->element_count(), 6);
- EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_EQ(mat_dim0major.element_count(), 6);
+ EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+ EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
ElementsAre(1, 4, 2, 5, 3, 6));
}
@@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->element_count(), 12);
+ EXPECT_EQ(lit_dim0minor.element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->data<int32>(),
+ EXPECT_THAT(lit_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
- auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+ auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+ EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->element_count(), 12);
- EXPECT_THAT(lit_dim0major->data<int32>(),
+ EXPECT_EQ(lit_dim0major.element_count(), 12);
+ EXPECT_THAT(lit_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+ auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+ EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = LiteralUtil::CreateR0<int32>(1);
- auto result = input->Slice({}, {});
- EXPECT_EQ(*input, *result);
+ auto result = input.Slice({}, {});
+ EXPECT_EQ(input, result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
- auto result = input->Slice({3}, {4});
+ auto result = input.Slice({3}, {4});
auto expected = LiteralUtil::CreateR1<float>({4.0});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto result = input_3x4->Slice({0, 2}, {2, 4});
+ auto result = input_3x4.Slice({0, 2}, {2, 4});
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_EQ(*input_2x3x2, *result);
+ auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+ EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = LiteralUtil::CreateR1<int64>({77});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
auto expected = LiteralUtil::CreateR0<float>(2.5f);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
output.PopulateWithValue<complex64>({4, 2});
auto expected =
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR0<half>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
auto input = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto output = input->Replicate<uint32>(3);
+ auto output = input.Replicate<uint32>(3);
auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_EQ(*output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
auto init_proc = [&](absl::Span<const int64> indexes) {
- source->Set(indexes, ++seqnr);
+ source.Set(indexes, ++seqnr);
return true;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
init_proc);
auto blank = Literal::CreateFromShape(shape);
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
blank_indexes.begin(), std::plus<int64>());
- auto bval = blank->Get<uint32>(blank_indexes);
- matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+ auto bval = blank.Get<uint32>(blank_indexes);
+ matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
return matched;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
check_proc);
EXPECT_TRUE(matched);
}
@@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = LiteralUtil::CreateR0<uint32>(0);
auto nine = LiteralUtil::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->CopyFrom(*nine));
- EXPECT_EQ(*zero, *nine);
+ TF_EXPECT_OK(zero.CopyFrom(nine));
+ EXPECT_EQ(zero, nine);
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
- EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
- EXPECT_EQ(vect->Get<uint32>({4}), 17);
+ TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+ EXPECT_EQ(zero.Get<uint32>({}), 17);
+ TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+ EXPECT_EQ(vect.Get<uint32>({4}), 17);
}
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
- EXPECT_EQ(*nine, *const_nine);
+ TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+ EXPECT_EQ(nine, const_nine);
}
{
// Copy 0 element to destination with zero elements.
- const auto empty = Literal::CreateFromShape(empty_r1_shape);
+ auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
- EXPECT_EQ(*empty, *const_empty);
+ TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+ EXPECT_EQ(empty, const_empty);
}
}
@@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
TEST_F(LiteralUtilTest, CopyFromArrays) {
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*scalar_42, *scalar_123);
- TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*scalar_42, *scalar_123);
- EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+ EXPECT_NE(scalar_42, scalar_123);
+ TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(scalar_42, scalar_123);
+ EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
- EXPECT_NE(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
- TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+ EXPECT_NE(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {matrix.get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get()});
+ Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0})};
+ Literal inner_tuple = LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal});
+ Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
// Create a tuple the same shape as the inner tuple of nested_tuple but with
// different values..
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(-5).get(),
- LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+ Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+ Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+ Literal tuple =
+ LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
// Overwrite the inner tuple element of nested_tuple with the contents of
// 'tuple'.
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{}));
+ TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
- LiteralUtil::CreateR0<int32>(4).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+ LiteralUtil::CreateR0<int32>(4)};
+ Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
// Copy from one element to the other.
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{0}));
+ TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
- Status status = matrix->CopyFrom(*vector);
+ Status status = matrix.CopyFrom(vector);
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Destination subshape incompatible"));
@@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) {
// Verify that the internal data views are consistent and that they
// are in little endian format
// TODO - modify if we make the data format machine endianess dependent
- auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
- Literal* l1 = m1.get();
- const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+ Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+ const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
@@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l2 = m2.get();
- const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+ const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
@@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->Populate<uint32>(generator));
+ TF_EXPECT_OK(literal.Populate<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+ TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// clang-format on
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->Convert(U32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
- EXPECT_EQ(*expected, *converted);
+ EXPECT_EQ(expected, converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
- std::unique_ptr<Literal> conv;
+ Literal conv;
- conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u32);
+ conv = s8.Convert(U32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u32);
- conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = s8.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u64);
+ conv = s8.Convert(U64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u64);
- conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s64);
+ conv = s8.Convert(S64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s64);
- conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *pred);
+ conv = s8.Convert(PRED).ConsumeValueOrDie();
+ EXPECT_EQ(conv, pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = bf16.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = bf16.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *int32_pred);
+ conv = pred.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, int32_pred);
- conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f32.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f64.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = s32.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f64.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = s32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = u32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = s32.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- conv = f16->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = f16.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- EXPECT_EQ(s32->Convert(TUPLE).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(S16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(U16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(F32).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(S32).status().code(),
+ EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) {
tensorflow::bit_cast<uint32>(100.f), 0xbeef});
auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->BitcastConvert(F32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
- Status status = literal->BitcastConvert(F64).status();
+ Status status = literal.BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(
absl::StrContains(status.error_message(), "bit widths are different"));
@@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
p.add_preds((i % 2) == (len % 2));
}
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- ASSERT_EQ(len, literal->data<bool>().size());
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal.data<bool>().size());
int i = 0;
- for (bool value : literal->data<bool>()) {
+ for (bool value : literal.data<bool>()) {
EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
@@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h2(2.0f);
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->data<half>().size());
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+ EXPECT_EQ(4, m.data<half>().size());
- LiteralProto p = l->ToProto();
+ LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
@@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- auto r = literal->data<half>();
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ auto r = literal.data<half>();
ASSERT_EQ(4, r.size());
EXPECT_EQ(h1, r[0]);
EXPECT_EQ(h2, r[1]);
@@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+ EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+ EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 1.0f);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
1.0f);
- nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 555.0f);
+ nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
555.0f);
@@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view,
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
}
TEST_F(LiteralUtilTest, LiteralMove) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- Literal literal(std::move(*matrix));
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(matrix));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get(),
- &nil_literal});
-
- EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
- std::vector<Literal> elements = nested_tuple->DecomposeTuple();
- EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+ Literal inner_elements[] = {
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}),
+ };
+ Literal tuple_elements[] = {
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+ LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal}),
+ };
+ Literal nested_tuple = LiteralUtil::MakeTuple(
+ {&tuple_elements[0], &tuple_elements[1], &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+ std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
ASSERT_EQ(elements.size(), 3);
@@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
- elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
- ));
+ elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+ elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+ std::vector<Literal> inner_elements;
+ inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+ inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+ elements.push_back(
+ LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
Literal literal;
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- literal = std::move(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(matrix);
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
}
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralSlice(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralSlice(matrix);
LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(42.0).get(),
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
- tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
-
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
- 3.0);
- tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ Literal elements[] = {
+ LiteralUtil::CreateR0<float>(42.0),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ };
+ auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+ tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-4.0);
}
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
// Literals constructed using CreateFromShape should be zero initialized.
- std::unique_ptr<Literal> scalar_f32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
- EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
- EXPECT_TRUE(scalar_f32->IsAll(0));
-
- std::unique_ptr<Literal> vector_s32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
- EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
- EXPECT_TRUE(vector_s32->IsAll(0));
-
- std::unique_ptr<Literal> tuple =
- Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
- ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
-
- EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
- EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
- EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
- EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+ Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32.IsAll(0));
+
+ Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32.IsAll(0));
+
+ Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1665,25 +1648,25 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto matrix_pred =
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
auto tuple = LiteralUtil::MakeTuple(
- {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+ auto nested_tuple =
+ LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
- return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
};
- EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
- EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
- EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
- EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
- EXPECT_EQ(*tuple, to_from_proto(*tuple));
- EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+ EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+ EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+ EXPECT_EQ(tuple, to_from_proto(tuple));
+ EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
- EXPECT_NE(*one_f32, *two_f32);
- EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+ EXPECT_NE(one_f32, two_f32);
+ EXPECT_NE(one_f32, to_from_proto(two_f32));
}
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1802,11 +1785,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
TEST_F(LiteralUtilTest, SortSparseElements) {
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
SparseIndexArray(10, 3), {});
- literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
- literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
- literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
- literal->SortSparseElements();
- EXPECT_EQ(literal->ToString(false),
+ literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal.SortSparseElements();
+ EXPECT_EQ(literal.ToString(false),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -1816,57 +1799,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
EXPECT_EQ(
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
"false");
EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(int64{2}));
EXPECT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(double{2.0}));
EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(static_cast<float>(half{2.0})));
EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+ Literal literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
- /*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 613449cf10..0cb1ae35f4 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -45,7 +45,7 @@ using absl::StrCat;
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+Literal ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
+ auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
+ TF_CHECK_OK(result.CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
}
}
});
@@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
+/* static */ Literal LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
+/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
+/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
- return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
+/* static */ Literal LiteralUtil::CreateToken() {
+ return Literal(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(0));
+ return LiteralUtil::CreateR0<uint8>(0);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(0));
+ return LiteralUtil::CreateR0<uint32>(0);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(0));
+ return LiteralUtil::CreateR0<uint64>(0);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(0));
+ return LiteralUtil::CreateR0<int8>(0);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(0));
+ return LiteralUtil::CreateR0<int32>(0);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(0));
+ return LiteralUtil::CreateR0<int64>(0);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(0));
+ return LiteralUtil::CreateR0<float>(0);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(0));
+ return LiteralUtil::CreateR0<double>(0);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(0));
+ return LiteralUtil::CreateR0<complex64>(0);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(1));
+ return LiteralUtil::CreateR0<uint8>(1);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(1));
+ return LiteralUtil::CreateR0<uint32>(1);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(1));
+ return LiteralUtil::CreateR0<uint64>(1);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(1));
+ return LiteralUtil::CreateR0<int8>(1);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(1));
+ return LiteralUtil::CreateR0<int32>(1);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(1));
+ return LiteralUtil::CreateR0<int64>(1);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(1));
+ return LiteralUtil::CreateR0<float>(1);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(1));
+ return LiteralUtil::CreateR0<double>(1);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(1));
+ return LiteralUtil::CreateR0<complex64>(1);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- -std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- -std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity());
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = absl::make_unique<Literal>(
+ Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
- absl::string_view value) {
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
+ Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
- literal->Set<uint8>({i}, value[i]);
+ literal.Set<uint8>({i}, value[i]);
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
- float from, float to, int64 rows, int64 cols) {
+/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
+ int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
+/* static */ Literal LiteralUtil::ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
@@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
- auto new_literal = absl::make_unique<Literal>(
+ Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
+ Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
@@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
+ new_literal.Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
break;
case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
+ new_literal.Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
break;
case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
+ new_literal.Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
break;
case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
+ new_literal.Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
break;
case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
+ new_literal.Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
break;
case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
+ new_literal.Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
break;
case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
+ new_literal.Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
break;
case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
+ new_literal.Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
break;
case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
+ new_literal.Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
- return std::move(
- *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
+ return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
+ return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
+ return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- literal.GetFirstElement<bfloat16>()));
+ return LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>());
case F16:
- return std::move(
- *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
+ return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
- return std::move(
- *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
+ return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
- return std::move(
- *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
+ return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
- return std::move(
- *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
+ return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
+ return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
+ return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(
- literal.GetFirstElement<complex64>()));
+ return LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>());
case F64:
- return std::move(
- *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
+ return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
+ return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
+ return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
+/* static */ Literal LiteralUtil::MakeTuple(
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
+/* static */ Literal LiteralUtil::MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements) {
+/* static */ Literal LiteralUtil::MakeTupleOwned(
+ std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
- element_shapes.push_back(element->shape());
+ element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
- literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2d6084a67a..2b181621ed 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -69,36 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR0(NativeT value);
+ static Literal CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
- static std::unique_ptr<Literal> CreateR1(
- const tensorflow::core::Bitmap& values);
+ static Literal CreateR1(absl::Span<const NativeT> values);
+ static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2(
+ static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2WithLayout(
+ static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3(
- std::initializer_list<
- std::initializer_list<std::initializer_list<NativeT>>>
- values);
+ static Literal CreateR3(std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3WithLayout(
+ static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4(
+ static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4WithLayout(
+ static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -139,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
- static std::unique_ptr<Literal> CreateSparse(
- absl::Span<const int64> dimensions, SparseIndexArray indices,
- absl::Span<const NativeT> values, bool sort = true);
+ static Literal CreateSparse(absl::Span<const int64> dimensions,
+ SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -155,130 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
+ static Literal CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
+ static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2D(
- const Array2D<NativeT>& values);
+ static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
+ static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3D(
- const Array3D<NativeT>& values);
+ static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
+ static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4D(
- const Array4D<NativeT>& values);
+ static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
+ static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
// Creates a new vector of U8s literal value from a string.
- static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
+ static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
- static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
- int64 rows, int64 cols);
+ static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+ int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3Projected(
+ static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
+ static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
- static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+ static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
- static std::unique_ptr<Literal> MakeTuple(
- absl::Span<const Literal* const> elements);
+ static Literal MakeTuple(absl::Span<const Literal* const> elements);
- static std::unique_ptr<Literal> MakeTupleFromSlices(
- absl::Span<const LiteralSlice> elements);
+ static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
- // std::vector<std::unique_ptr<Literal>> elements = ...;
+ // std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements);
+ static Literal MakeTupleOwned(std::vector<Literal> elements);
- // This overload lets you pass a braced list of unique_ptr<Literal>s to
+ // This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
- // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+ // Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
- // The arguments to this function must all be unique_ptr<Literal>.
+ // The arguments to this function must all be Literal.
template <typename... Ts>
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::unique_ptr<Ts>... elements) {
- std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
- std::move(elements)...};
- std::vector<std::unique_ptr<Literal>> v;
+ static Literal MakeTupleOwned(Ts... elements) {
+ std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+ std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
// Create a constant token literal. Token types have no value.
- static std::unique_ptr<Literal> CreateToken();
+ static Literal CreateToken();
// Creates a new Literal object with its values havings the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type, absl::Span<const int64> dimensions);
+ static Literal CreateFromDimensions(PrimitiveType primitive_type,
+ absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(
- const LiteralSlice& bf16_literal);
+ static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(
- const LiteralSlice& f32_literal);
+ static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
- static std::unique_ptr<Literal> ReshapeSlice(
- absl::Span<const int64> new_dimensions,
- absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
+ static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major,
+ const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -286,7 +275,7 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator);
@@ -297,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+ T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -307,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+ T stddev);
//
// End of factory methods.
@@ -322,44 +311,43 @@ class LiteralUtil {
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
- literal->Set({}, value);
+ literal.Set({}, value);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- absl::Span<const NativeT> values) {
- auto literal = absl::make_unique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+ Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateR2(values);
+ literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -384,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -422,23 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
+/* static */ Literal LiteralUtil::CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
- indices.max_indices()));
- literal->PopulateSparse(indices, values, sort);
+ Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+ indices.max_indices()));
+ literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -446,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateFromArray(values);
+ literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -514,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -565,33 +549,29 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
- NativeT value) {
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
- literal->PopulateWithValue(value);
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value) {
+ Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ Literal literal(shape);
+ TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index f9473d372b..0f86f9f35e 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
-StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
- const Shape& shape, const Layout* layout) {
+StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
+ const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()));
}
- auto result = absl::make_unique<Literal>(literal_shape);
- result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+ Literal result(literal_shape);
+ result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
- absl::Span<const float> field = result->data<float>();
+ absl::Span<const float> field = result.data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
absl::string_view sp;
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 98dccaa9a2..d6d2ff1521 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
- StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
- const Layout* layout = nullptr);
+ StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd6e20b693..9da5dc0d2d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}
-StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number) {
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- return ToBuffer(client, /*device_ordinal=*/0, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
-StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
+StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
@@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
-StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
+StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
@@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
- std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+ std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- pushed = ToBuffer(client, device_ordinal, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 78b3c598b9..1d5dfe5911 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
-StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number);
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -65,7 +65,7 @@ class LocalShapedBuffer {
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
- StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+ StatusOr<Literal> ToLiteral() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@@ -117,7 +117,7 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
- StatusOr<std::unique_ptr<Literal> > Execute(
+ StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape> >& shapes_with_layout);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 450d3fe5af..521490e76c 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
- std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
+ Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -346,25 +346,25 @@ tensorflow::ImportNumpy();
// Literal
-%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
+%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
- $1 = literal_status.ValueOrDie().get();
+ $1 = &literal_status.ValueOrDie();
}
-%typemap(out) std::unique_ptr<Literal> {
+%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
- $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+ $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
+ StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
- temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
+ temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index fc6511bef5..b0aa024c74 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}
-StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
- TF_RETURN_IF_ERROR(
- CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
+ TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 8cae175185..40ff2d9ad2 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
-StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 9f1afa2671..05325367f5 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -529,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -556,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -565,7 +565,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;
const Shape& shape = ShapeInference::InferConvolveShape(
- lhs_literal->shape(), rhs_literal->shape(),
+ lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();
@@ -585,18 +585,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
- CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+ CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
- absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+ result_literal.shape().dimensions(1),
+ result_literal.shape().dimensions(2),
+ result_literal.shape().dimensions(3));
result->Each([&](absl::Span<const int64> indices, float* value) {
- *value = result_literal->Get<float>(indices);
+ *value = result_literal.Get<float>(indices);
});
return result;
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 3ec0192148..a1b0f4045f 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
- LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
+ LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}
@@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
@@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+ {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}
@@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
- *actual_literal, ErrorSpec(0.0001));
+ {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
} // namespace
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 43fd8fe1bd..84fe5b17d1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<float>(expected);
+ Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3d18fe3be2..2a0823aeca 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -527,7 +527,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ HloInstruction::CreateConstant(literal.Clone()));
}
}
@@ -546,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
+ Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +676,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
- HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -1469,7 +1469,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@@ -1572,7 +1572,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1607,7 +1607,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2062,7 +2062,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (!converted_pad_literal.ok()) {
return false;
}
- return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2223,8 +2223,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type())
- .CloneToUnique())),
+ LiteralUtil::Zero(convolution->shape().element_type()))),
{}));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index a0db4563fb..3fc1ba2427 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector)};
+ Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec281ae68f..30d33e0d35 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
+ elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 388fd5df99..e032b5c624 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index c30abd1d3e..795beb9ff5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
+ Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()})));
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 414bfe7999..17e5090505 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()});
- auto inner_tuple1 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
+ auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+ Literal element1 = LiteralUtil::CreateR0<int64>(3);
+ auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- inner_tuple0->shape(), tuple_constant, 0));
+ inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 0826380f65..0ac4a65ec6 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 6bf3810967..1deb412064 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
- Shape vshape = input_literal1->shape();
+ Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -78,13 +78,13 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -125,8 +125,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
- error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -213,7 +212,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
- *result, error_spec_);
+ result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c661..5cc6d01c0f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+ LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+ LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index bb105194f1..7af51db55a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 1b3be199f6..852f34e06d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -56,9 +56,9 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
- RunTest(hlo_text, {lhs.get(), rhs.get()});
+ Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 4ed91ef187..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
- std::unique_ptr<Literal> relayed_out_literal;
+ Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->untyped_data();
+ source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index bda8ebe579..d237f8930b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
- LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index fa84d77223..b0061fa655 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-
// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
- auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9d85d746d8..2a6415d0b6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index 4550f36fdf..780539c164 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30317..f8120a5fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+ TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
- LiteralUtil::CreateR0<int32>(5).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR4FromArray4D<float>(array),
+ LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 8a45939c61..f837816cea 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
- std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
- if (result == nullptr) {
+ if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1efc12..4da42844bd 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a3fcc0fefa..b76c50bb5b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(operand->shape().element_type()))));
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index eb6affadc8..e07a196d11 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<int32>(
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e09d5868f2..9b18b0284f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index a2f683b690..064b86493d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -54,9 +54,8 @@ namespace xla {
namespace {
template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
@@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
@@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
@@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
- .CloneToUnique();
+ .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloModule& module, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
@@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+ return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
@@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = input_literal->CloneToUnique();
+ evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal : arg_literals) {
+ arg_literal_ptrs.push_back(&literal);
+ }
+ return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction) {
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
- HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+ CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
- return nullptr;
+ return false;
}
- return result_or.ConsumeValueOrDie();
+ *result = result_or.ConsumeValueOrDie();
+ return true;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
- HloInstruction::CreateConstant(it->second->CloneToUnique()));
+ HloInstruction::CreateConstant(it->second->Clone()));
}
}
@@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
- HloInstruction::CreateConstant(operand.CloneToUnique());
+ HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = input_literal->CloneToUnique();
+ evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
- std::unique_ptr<Literal>* reshaped_start_indices) {
+ Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
@@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
- return std::cref(**reshaped_start_indices);
+ return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
- std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_start_indices;
+ Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
- result->CopyElementFrom(operand, input_index, output_index));
+ result.CopyElementFrom(operand, input_index, output_index));
return true;
};
@@ -977,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = absl::make_unique<Literal>(
- ShapeUtil::GetTupleElementShape(operand->shape(), index));
- return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
- /*dest_shape_index=*/{},
- /*src_shape_index=*/{index});
+ evaluated_[get_tuple_element] =
+ Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
- auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
- evaluated_[copy] = std::move(result);
+ evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@@ -1004,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@@ -1036,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@@ -1056,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result;
+ Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@@ -1081,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
- evaluated_[select] = on_true.CloneToUnique();
+ evaluated_[select] = on_true.Clone();
} else {
- evaluated_[select] = on_false.CloneToUnique();
+ evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@@ -1097,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
- evaluated_[tuple_select] = on_true.CloneToUnique();
+ evaluated_[tuple_select] = on_true.Clone();
} else {
- evaluated_[tuple_select] = on_false.CloneToUnique();
+ evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@@ -1108,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop carried valued with the input to the While instruction.
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+ auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
@@ -1118,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
- TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
- *cond_comp, {lcv.get()}));
- keep_going = cond_val->GetFirstElement<bool>();
+ TF_ASSIGN_OR_RETURN(auto cond_val,
+ cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+ keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
- *body_comp, {lcv.get()}));
- VLOG(3) << "Loop iteration result: " << body_val->ToString();
+ *body_comp, {&lcv}));
+ VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@@ -1139,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1179,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
- auto result_values_literal =
- absl::make_unique<Literal>(values_literal.shape());
- result_values_literal->PopulateR1(
+ Literal result_keys_literal(keys_literal.shape());
+ result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal result_values_literal(values_literal.shape());
+ result_values_literal.PopulateR1(
absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
- std::unique_ptr<Literal> result_tuple;
+ Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple = LiteralUtil::MakeTuple(
- {result_pair.first.get(), result_pair.second.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
- auto values_result_literal =
- absl::make_unique<Literal>(values_literal.shape());
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
+ .Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ .Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first->Reshape({1, r1_length}));
+ r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
- *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
- *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ r1_result_pair.second.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
- result_tuple = LiteralUtil::MakeTuple(
- {keys_result_literal.get(), values_result_literal.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
- VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1248,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
- const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1319,28 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
// Explicit instantiation of templatized Evaluate* methods.
//
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloModule& module, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
- const Literal*>(const HloComputation& computation,
- absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- HloInstruction* instruction,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 72252bafc7..21e676d671 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloModule& module,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloComputation& computation,
- absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
@@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
- StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction);
- // Same as Evaluate, except returning nullptr on error.
- std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+ // Same as Evaluate, except returning false on error and accepts an output
+ // pointer.
+ bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
- StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+ StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
- HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+ StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
- HloOpcode opcode, const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+ const Literal& operand);
- StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfig& precision_config, const Literal& lhs,
- const Literal& rhs);
+ StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
- return *(it->second);
+ return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-orderring.
// Must be cleared for each evaluation.
- tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
- evaluated_;
+ // Storing Literal in place require the container to have pointer stability so
+ // we cannot use FlatMap any more.
+ std::unordered_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
- static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 102ebb24ab..16411eb078 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
evaluator_ = absl::make_unique<HloEvaluator>();
}
- std::unique_ptr<Literal> Evaluate(
- absl::Span<const Literal* const> arg_literals = {}) {
+ Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<HloEvaluator> evaluator_;
- void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> input, float aabs = 0) {
+ void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+ float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
- b.AddInstruction(
- HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+ b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto element_type = expected->shape().element_type();
+ auto element_type = expected.shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
}
- void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> lhs,
- std::unique_ptr<Literal> rhs) {
+ void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+ Literal rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(
- HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+ HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
bool use_bfloat16_;
@@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- Shape shape = low->shape();
+ Shape shape = low.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = LiteralUtil::CreateR0<float>(1.f);
- Shape shape = value->shape();
+ Shape shape = value.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- Shape shape = on_true->shape();
+ Shape shape = on_true.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
auto c2 =
@@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+ std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
lhs_instruction, param_rhs2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate(args);
+ Literal result = Evaluate(args);
auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies Reshape operation is correctly evaluated.
@@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
@@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
- result->EachCell<NativeT>(
- [&](absl::Span<const int64> indices, NativeT value) {
- std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
- });
+ result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+ });
}
// Verifies Broadcast operation is correctly evaluated.
@@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction, {1, 2}));
+ output_literal.shape(), literal_instruction, {1, 2}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloInstruction::CreateConstant(std::move(input_literal)));
// Broadcast dimension should be empty in the case of scalars.
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction,
+ output_literal.shape(), literal_instruction,
/*broadcast_dimensions=*/{}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int64>(
{{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<int64>({100, 200});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
- ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
- ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
PaddingConfig CreatePaddingConfig(
@@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
shape, operand_instruction, padding_value_instruction, padding_config));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
@@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
@@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected_array = Array2D<float>({
@@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = Array2D<float>({
{22.f, 28.f},
@@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest,
@@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
@@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
std::iota(input_elems.begin(), input_elems.end(), -7);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
std::iota(filter_elems.begin(), filter_elems.end(), -31);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
@@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
/*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 1, 8);
expected_array.FillWithYX(
Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
module().AddEntryComputation(b.Build());
HloEvaluator hlo_eval;
- std::unique_ptr<Literal> result =
- hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+ Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
// Reducing many numbers should be fast because it doesn't create
@@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({6, 18});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
@@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
/*strides=*/{2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
shape, operand, update, start_indices));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto result_inner_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
- auto expected = LiteralUtil::MakeTuple({
- result_inner_literal.get(),
- result_inner_literal.get(),
- });
+ auto expected =
+ LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
+ Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, &param0_literal}, {square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
- auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+ auto result =
+ evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1906,12 +1901,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1930,12 +1925,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1954,14 +1949,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>(
+ LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1980,15 +1974,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest,
@@ -2008,15 +2001,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -2035,12 +2027,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -2059,13 +2050,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2084,11 +2074,10 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2108,12 +2097,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> start_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2138,15 +2127,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2171,15 +2158,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2205,15 +2191,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2239,15 +2223,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2273,17 +2255,15 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(
+ LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
- ErrorSpec{0.1, 0.01}));
+ Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2309,15 +2289,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2343,15 +2321,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2376,21 +2353,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, //
{{-40, 40}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest,
@@ -2416,21 +2390,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2455,16 +2426,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2489,17 +2458,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2524,13 +2490,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *operand,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ operand, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2557,16 +2521,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<int32>({10, 61, 32});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -2603,11 +2564,10 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+ Literal arg = LiteralUtil::CreateR1<bfloat16>(
{bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+ Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
}
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 63303aef1e..7f090a52db 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,15 +246,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
+ parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@@ -262,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
+ parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@@ -978,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1157,8 +1155,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1231,9 +1229,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = absl::make_unique<Literal>(dot->shape());
+ Literal result(dot->shape());
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1280,8 +1278,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@@ -1289,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@@ -1311,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1439,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = absl::make_unique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1463,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1557,9 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1568,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1651,9 +1647,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
- results[i] = absl::make_unique<Literal>(result_shape);
+ results[i] = Literal(result_shape);
}
Status eval_status;
@@ -1667,7 +1663,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
- TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@@ -1703,8 +1699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
- absl::InlinedVector<std::unique_ptr<Literal>, 1>
- embedded_operands;
+ absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@@ -1717,11 +1712,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
- [](const std::unique_ptr<Literal>& ptr) {
- return ptr.get();
- });
+ [](Literal& literal) { return &literal; });
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@@ -1729,10 +1722,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
- result_values[0] = computed_result->Get<ReturnT>({});
+ result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
- result_values[i] = computed_result->Get<ReturnT>(
+ result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@@ -1748,9 +1741,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
- auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
- TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@@ -1781,10 +1774,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@@ -1834,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1856,16 +1848,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that the we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1916,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1935,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1961,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1970,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
// Returns an ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2230,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2260,7 +2251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@@ -2299,19 +2290,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@@ -2361,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2575,7 +2566,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
- result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@@ -2645,9 +2636,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2660,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2676,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2689,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
@@ -2710,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2723,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2745,10 +2736,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2757,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2782,10 +2773,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f06c98f2e7..85fa3ce964 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
- instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@@ -527,7 +527,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
- std::unique_ptr<Literal> literal) {
+ Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index bf25157395..4f6cac1396 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -359,8 +359,7 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
- static std::unique_ptr<HloInstruction> CreateConstant(
- std::unique_ptr<Literal> literal);
+ static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index fb7345a2ad..e92882c22a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
- : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+ : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- if (literal_ != nullptr) {
+ if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
+ *literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
+ CHECK(literal_.has_value());
+ return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if (literal_ != nullptr &&
+ if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ *proto.mutable_literal() = literal_.ToProto();
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index c3a7801164..2d7bc83855 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
- explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is literal associated with this instruction.
- bool HasLiteral() const { return literal_ != nullptr; }
+ bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
- string TracingTag() const { return literal_->GetR1U8AsString(); }
+ string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c54360b063..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = absl::make_unique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 66ac1f66fd..fa7f216321 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals) {
+ const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
- literal_pointers.push_back(literal.get());
+ literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
- argument_pointers.push_back(argument.get());
+ argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = absl::make_unique<Literal>();
+ Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, options.outfeed_shape, literal.get()));
+ executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
- std::vector<std::unique_ptr<Literal>> exec_results;
+ std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 76d8b92bed..2e934bf66a 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
- std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+ std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals);
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- const ShapedBuffer& buffer);
+ const absl::Span<const Literal> literals);
+ StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const Literal* const> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal* const> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@@ -140,7 +139,7 @@ class HloRunner {
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
- StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0cac210c24..8f0423bb1c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));
auto module = CreateNewModule();
@@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 37b774b8a5..06f0e1ed25 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> inner_broadcast_result,
+ Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 9746d176cc..df9cbab915 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
- Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
- StatusOr<Literal*> TakeOwnership(
- StatusOr<std::unique_ptr<Literal>> literal_or_error) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- std::move(literal_or_error));
+ StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
- std::vector<std::unique_ptr<Literal>> owned_literals_;
+ std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 5695bc2420..93a74dbfa6 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 5dea124768..a06d6113e8 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
- std::unique_ptr<Literal> result_literal;
+ Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
- TF_ASSIGN_OR_RETURN(result_literal,
- evaluator_->Evaluate<std::unique_ptr<Literal>>(
- *computation, arg_literals));
+ TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+ *computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
- result_literal->shape(), run_options->allocator(),
+ result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- run_options->stream(), *result_literal, result));
+ run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 69c7e42601..f8baba03c3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index f0e2566a3f..922ebdf0e3 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
- *module->add_arguments() = literal->ToProto();
+ *module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
- *module->mutable_result() = literal->ToProto();
+ *module->mutable_result() = literal.ToProto();
return Status::OK();
}
@@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> result_literal,
+ Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
- if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
- result_literal->shape())) {
- *result->mutable_literal() = result_literal->ToProto();
+ if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+ *result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
- result_literal->Relayout(*return_shape)->ToProto();
+ result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- stream.get(), *literal, shaped_buffer));
+ stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, *literal);
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+ literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), *literal));
- *result->mutable_literal() = literal->ToProto();
+ executor, arg->shape_with_layout(), literal));
+ *result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, /*arg_literals=*/{}));
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+ *module, /*arg_literals=*/{}));
// Since the result layout is non-effective to the Evaluator results, explicit
// relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
+ result_literal = result_literal.Relayout(arg->output_layout());
}
- *result->mutable_literal() = result_literal->ToProto();
+ *result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index b8d2d546e5..a21e586efa 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
@@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 21725946b3..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
- virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source);
+ StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+ const Shape& shape,
+ const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2b2a2eb42a..e9a07b14ed 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})};
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603c7e..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 0bf4556b43..c257566fb2 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -41,7 +41,6 @@ limitations under the License.
namespace xla {
namespace {
-
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001, 0.0001};
@@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Add(lhs_param, rhs_param);
@@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<int64> rhs{-1,
0,
@@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Sub(lhs_param, rhs_param);
@@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
XlaBuilder b(TestName());
std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
Lt(lhs_param, rhs_param);
- ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
+ ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
- auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
+ auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+ Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);
auto sum1 = Add(a_constant, b_constant);
@@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto sum = ConstantR0<float>(&b, 0.0f);
- auto param = Parameter(&b, 0, param_literal->shape(), "param");
+ auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
@@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
@@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
@@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
@@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
@@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
@@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
@@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
+ Literal literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
- client_->TransferToServer(*literal3).ConsumeValueOrDie();
+ client_->TransferToServer(literal3).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
- auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
+ auto param3 = Parameter(&b, 3, literal3.shape(), "param2");
Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
@@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
Array3D<float> expected(0, 7, 0);
@@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
- auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
-0.79, 1.41, 1.21, 1.05});
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Tanh(input);
ComputeAndCompareR1<float>(
@@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Exp(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::exp(input_literal->Get<float>({i})));
+ expected_result.push_back(std::exp(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Log(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::log(input_literal->Get<float>({i})));
+ expected_result.push_back(std::log(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
- LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- auto a = ConstantLiteral(&builder, *a_literal);
+ Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ auto a = ConstantLiteral(&builder, a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
- auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
+ auto y = Parameter(&builder, 1, y_literal.shape(), "y");
auto slice = Slice(x, {1}, {2}, {1});
Sub(slice, y);
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3adb6..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+ input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
- {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({0, 0}).get(),
- LiteralUtil::CreateR1<float>({16, 20}).get()});
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+ LiteralUtil::CreateR1<float>({0, 0}),
+ LiteralUtil::CreateR1<float>({16, 20})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
- auto expected = LiteralUtil::MakeTuple(
- {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
- LiteralUtil::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+ LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, *expected,
+ &builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
+ auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
- Parameter(&builder, 4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
- auto input_parameter =
- Parameter(&builder, 0, input_literal->shape(), "input");
- auto scale_parameter =
- Parameter(&builder, 1, scale_literal->shape(), "scale");
- auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
- auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+ auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+ auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
- Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
- client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+ client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
- auto expected =
- LiteralUtil::MakeTuple({expected_grad_activation.get(),
- LiteralUtil::CreateR1<float>(grad_scale).get(),
- LiteralUtil::CreateR1<float>(grad_offset).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+ LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 65589b0d6a..e9728e636f 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
- {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
- .get(),
+ {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
- .get()});
+ {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
- {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
- .get(),
+ {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
- .get()});
+ {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index fe4267c73b..dde19fb65d 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
- client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+ client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
- client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+ client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
- {r3_implicit_global_data.get(), r3_global_data.get()},
+ &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+ ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
+ &builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 74d4d2eb10..9966e4606e 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralSlice(*result, {0}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralSlice(*result, {1}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index b1d18210ea..8b31e53707 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant =
- ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+ auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
- auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
- auto tuple = LiteralUtil::MakeTuple({elem.get()});
- Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
+ auto tuple = LiteralUtil::MakeTuple({&elem});
+ Call(&builder, callee, {ConstantLiteral(&builder, elem)});
- ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+ ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index a4eb57fc7b..2f1510ff69 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
- auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
Add(p0, p1);
auto param0_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto param1_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
@@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto computation = computation_status.ConsumeValueOrDie();
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
- auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
- client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
- auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+ auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
// Match
auto status = client_->Execute(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 8a236db0ff..fbdf0fcb65 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
- return result.ValueOrDie()->ToString();
+ return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
- verify_output(*actual, "");
+ verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual,
+ verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
- if (ShapeUtil::IsTuple(literal->shape())) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal->shape()));
+ ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
- std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+ std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
- literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+ literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+ ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
- client_->TransferToServer(*literal_relayout));
+ client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
- verify_output(*actual, error_message);
+ verify_output(actual, error_message);
return Status::OK();
};
@@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
@@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
@@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+ Literal expected_literal = LiteralUtil::CreateR1U8(expected);
- VLOG(1) << "expected: " << expected_literal->ToString();
- VLOG(1) << "actual: " << actual->ToString();
+ VLOG(1) << "expected: " << expected_literal.ToString();
+ VLOG(1) << "actual: " << actual.ToString();
- EXPECT_EQ(expected, actual->GetR1U8AsString());
+ EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
- ? *LiteralUtil::ConvertF32ToBF16(literal)
- : literal);
+ ? LiteralUtil::ConvertF32ToBF16(literal)
+ : LiteralSlice(literal));
}
std::unique_ptr<GlobalData>
@@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
- return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 22dfdfb0e4..9d32f4f517 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// This executes the computation via the reference client (which connects a
// interpreter backend). The result is used as the expected values of the
// computation.
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
+ StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
+ return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- absl::Span<const Literal> arguments);
+ StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
+ XlaBuilder* builder, absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@@ -412,9 +411,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -438,9 +435,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, absl::Span<const NativeT> expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -464,9 +459,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -490,9 +485,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -516,9 +511,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -542,13 +537,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR0(value);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -556,13 +551,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR1(values);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -570,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -584,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index c898dacf48..6f2ca84bb6 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
- auto computed, client_->Transfer(*data, &expected_literal->shape()));
+ auto computed, client_->Transfer(*data, &expected_literal.shape()));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralSlice(*result, {0}));
+ LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralSlice(*result, {1}));
+ LiteralSlice(result, {1}));
- EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
- EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+ EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
@@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
+ LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
- client_->Transfer(*results[0], &expected_result->shape()));
+ client_->Transfer(*results[0], &expected_result.shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 03d5696499..6ef7ca035f 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
absl::Span<GlobalData* const> arguments,
float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
- std::unique_ptr<Literal> result =
+ Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data_handle).ConsumeValueOrDie();
+ Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
- client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
- client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 8226b6de3f..3b0414a604 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
- StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const XlaOp& operand, XlaBuilder* builder,
- Layout* output_layout = nullptr) {
+ StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
- return literal->Get<Scalar>({});
+ return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<int32>({4, 6});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index be017477d8..9811a015e9 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "y");
auto y_bcast = Broadcast(y, {1, 5, 7});
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 25d10ab00a..32cac499c7 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
- LiteralUtil::CreateR0<float>(25.0f).get()}),
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+ LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
- LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<bool>(true).get(),
- LiteralUtil::CreateR0<float>(12.2f).get(),
- LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(true),
+ LiteralUtil::CreateR0<float>(12.2f),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(46.6f).get(),
- LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
- .get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
- LiteralUtil::CreateR0<float>(9.3f).get()})
- .get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(46.6f),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+ LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
@@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
- LiteralUtil::CreateR0<float>(b).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4937574831..72ff1e74a4 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4D(input_array);
+ Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *input_literal);
+ ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
@@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})}));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
+ Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
- LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ LiteralSlice(result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+ ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 7a203d6873..5f063e6784 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, U32);
@@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 38b6da4fa9..fd98bf29b8 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index e0a1538850..070b092d18 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
- auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
+ auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r5).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r5,
+ ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(param0)),
- std::move(*LiteralUtil::CreateFromArray(param1))},
+ {LiteralUtil::CreateFromArray(param0),
+ LiteralUtil::CreateFromArray(param1)},
error_spec_);
}
@@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
- auto expected_r3 =
- expected_r1->Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
+ auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+ client_->TransferToServer(input_r3).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r3,
+ client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
@@ -891,9 +890,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillIota(10);
- ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))});
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)});
}
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
@@ -928,8 +926,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
/*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
/*feature_group_count=*/64);
- ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data))},
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 6784c16715..ba3e9c436e 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
- weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto weights = ConstantLiteral(&builder, *weights_literal);
+ weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto weights = ConstantLiteral(&builder, weights_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
- activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
- auto activations = ConstantLiteral(&builder, *activations_literal);
+ activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
+ auto activations = ConstantLiteral(&builder, activations_literal);
auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto forward_conv =
ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 526626c1dd..1407e68d9a 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ auto constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+ TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4(
+ TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
- Shape shape = literal->shape();
+ Shape shape = literal.shape();
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(module), {literal.get()});
- LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {&literal});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
- Layout* literal_layout =
- literal->mutable_shape_do_not_use()->mutable_layout();
+ Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
@@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}
@@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+ LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}
void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+ Literal literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+ LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
- auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
+ auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index d12a4e7fcd..410732c07b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
- EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+ EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(
- *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}
// On the GPU backend, constants get special handling. Someone might pass a
@@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 6f7fc0e6e5..a693fa3595 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
module->AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
- Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index eb15fc0593..e0f23b0fa8 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
// Try copying the elements back and comparing it
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
handles1[0].reset();
handles1[1].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
/// Try deallocating one of the repeated elements, then copy
handles[0].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 5873516442..0171f51583 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
+ LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
- param.dot_lhs_row_major)));
+ Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
- client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
- std::unique_ptr<Literal> dot_rhs_lit =
+ Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
- client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> addend_data;
- std::unique_ptr<Literal> addend_lit;
+ Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;
if (param.has_addend) {
@@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
- addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
+ addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}
XlaBuilder builder(TestName());
@@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 9bf3767ca3..7501c6d957 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie();
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*LiteralUtil::CreateR0(input_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(input_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_value =
- std::move(*LiteralUtil::CreateR0(update_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(update_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_value =
- std::move(*LiteralUtil::CreateR0(expected_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(expected_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR1(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR3FromArray3D<NativeT>(values);
- LOG(INFO) << name << ":" << literal->ToString();
+ Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << literal.ToString();
}
};
@@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- auto input = ConstantLiteral(&builder, *input_literal);
+ auto input = ConstantLiteral(&builder, input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- stream.get(), *start_indices_literal, buffer));
+ stream.get(), start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 5116e60ca6..b08ece0e63 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index bf1de02ba9..738f2600d4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7cb2f0cedf..9c94acb437 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -117,9 +117,9 @@ class FusionTest : public HloTestBase {
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
}
@@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
@@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
// Every element of result should be y = x^2 = 4.0.
for (int i = 0; i < rand_dim0_size; ++i) {
for (int j = 0; j < dim1_size; ++j) {
- EXPECT_EQ(4.0, result->Get<float>({i, j}));
+ EXPECT_EQ(4.0, result.Get<float>({i, j}));
}
}
}
@@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
@@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
@@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
@@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
@@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
@@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
@@ -782,19 +782,17 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> operand =
- LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ Literal operand = LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text, config));
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
- test_runner_.Execute(std::move(module), {operand.get()},
- /*run_hlo_passes=*/false));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
+ test_runner_.Execute(std::move(module), {&operand},
+ /*run_hlo_passes=*/false));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
- *result));
+ LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ result));
}
class FusionClientLibraryTest : public ClientLibraryTestBase {};
@@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
// where overflow is OK.
Array2D<uint32> arr(32, 32);
arr.FillUnique();
- std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({0, 1}));
- std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({1, 0}));
- XlaOp p0 = AddParam(*l1, &b);
+ XlaOp p0 = AddParam(l1, &b);
XlaOp sum = p0;
for (int i = 1; i < kNumParams; ++i) {
- auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
sum = sum + p0 * pN * pN;
}
@@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) {
auto param0_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
- client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
- client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param2_literal, device_ordinal)
.ConsumeValueOrDie();
// Build executable.
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 6d63498044..daa89398a6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -58,10 +58,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -79,10 +79,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -100,11 +100,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -122,11 +121,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -144,11 +143,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -166,13 +165,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -190,13 +188,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -214,10 +211,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -235,11 +232,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -257,9 +253,9 @@ ENTRY main {
slice_sizes={1, 0}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -281,11 +277,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -307,11 +303,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
+ Literal start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -333,11 +329,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -359,11 +355,11 @@ ENTRY main {
ROOT result = u32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -381,10 +377,10 @@ ENTRY main {
slice_sizes={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -402,9 +398,9 @@ ENTRY main {
slice_sizes={1}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -422,10 +418,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -446,10 +442,10 @@ ENTRY main {
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -470,11 +466,10 @@ ENTRY main {
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -495,11 +490,11 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -520,13 +515,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest,
@@ -548,13 +542,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -575,10 +568,10 @@ ENTRY main {
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -599,11 +592,10 @@ ENTRY main {
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> operand_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::vector<std::unique_ptr<xla::GlobalData>> result_data,
client_->ExecuteParallel(computation_instances));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
- *result_literal);
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 3df99aac7d..bdd4fd7e3d 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -136,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() {
return debug_options;
}
-StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments);
}
-std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_
.Execute(std::move(module), arguments,
/*run_hlo_passes=*/false)
.ValueOrDie();
}
-std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}
@@ -188,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
TF_ASSIGN_OR_RETURN(auto reference,
reference_runner_.Execute(std::move(reference_module),
arguments, run_hlo_passes));
- return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
+ return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
error);
}
@@ -223,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompare(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -243,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -277,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return test_runner_
.Execute(std::move(module_or_status.ValueOrDie()),
fake_argument_ptrs, /*run_hlo_passes=*/true)
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 21d77c0cc4..0ae4bdc104 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -115,16 +115,16 @@ class HloTestBase : public ::testing::Test {
}
// Executes the given module and return the result as a Literal.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Same as above, except the module will be executed without running any HLO
// passes on it.
- std::unique_ptr<Literal> ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
- std::unique_ptr<Literal> ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Executes the given hlo module on two backends and compares results.
//
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 96f72212f3..43cca91f64 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -155,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
absl::Span<const NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -176,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
absl::Span<const NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -223,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -232,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 4151bfae03..b6f9b8156b 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,11 +31,11 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal literal = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal lhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(64).get(),
- LiteralUtil::CreateR0<int32>(42).get(),
+ Literal rhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(64),
+ LiteralUtil::CreateR0<int32>(42),
});
- CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
};
ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
}
@@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto two = LiteralUtil::CreateR0<float>(2);
auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
- CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
};
tensorflow::Env* env = tensorflow::Env::Default();
@@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
- EXPECT_EQ("2", literal->ToString());
+ EXPECT_EQ("2", literal.ToString());
} else if (result.find("actual") != string::npos) {
- EXPECT_EQ("4", literal->ToString());
+ EXPECT_EQ("4", literal.ToString());
} else if (result.find("mismatches") != string::npos) {
- EXPECT_EQ("true", literal->ToString());
+ EXPECT_EQ("true", literal.ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
@@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
- ::testing::AssertionResult result =
- LiteralTestUtil::Equal(*expected, *actual);
+ ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
EXPECT_THAT(result.message(),
::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
EXPECT_THAT(result.message(),
@@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
@@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
@@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b =
LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
- EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
- EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 237a4a361e..dbdd20daf0 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
DefaultExecutableBuildOptions(), options);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
@@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 1a823cf189..a99b43f469 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
auto y = ConstantR0<float>(&builder, 123.0f);
Add(x, y);
- auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
- LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
auto y = ConstantR1<float>(&builder, {});
Add(x, y);
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
- LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
@@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions().set_execution_profile(&profile));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
@@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
ScopedShapedBuffer result_colmaj =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with the parameter values in a different order.
ScopedShapedBuffer result_param_swap =
ExecuteLocallyOrDie(computation, {&y_array, &x_array});
- LiteralTestUtil::ExpectR2Near<float>(
- {{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_param_swap), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ ShapedBufferToLiteral(result_param_swap),
+ error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
@@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with row-major result layout.
@@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_rowmaj),
+ ShapedBufferToLiteral(result_rowmaj),
error_spec_);
}
@@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {2}));
+ LiteralSlice(result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
+ LiteralSlice(result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
+ LiteralSlice(result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralSlice(result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
options, DefaultExecutableRunOptions());
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
+ auto y_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});
- auto x_buffer = LiteralToShapedBuffer(*x_literal);
- auto y_buffer = LiteralToShapedBuffer(*y_literal);
+ auto x_buffer = LiteralToShapedBuffer(x_literal);
+ auto y_buffer = LiteralToShapedBuffer(y_literal);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
@@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
+ Literal result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
+ LiteralSlice(result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
- LiteralSlice(*result_0_literal, {1}));
+ LiteralSlice(result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
- std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
+ Literal result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
- LiteralSlice(*result_1_literal, {0}));
+ LiteralSlice(result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
- LiteralSlice(*result_1_literal, {1}));
+ LiteralSlice(result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
// Feed in a tuple where each two-element vector element is {tuple_index,
// -tuple_index}.
- std::vector<std::unique_ptr<Literal>> arg_elements;
+ std::vector<Literal> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
- std::unique_ptr<Literal> arg_literal =
- LiteralUtil::MakeTupleOwned(std::move(arg_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
}
}
@@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::vector<std::unique_ptr<Literal>> outer_tuple_elements;
+ std::vector<Literal> outer_tuple_elements;
for (int i = 0; i < kFanout; ++i) {
- std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
+ std::vector<Literal> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
@@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
}
auto arg_literal =
LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
- LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
- error_spec_);
+ LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
+ LiteralSlice(result_literal, {i, j}),
+ error_spec_);
}
}
}
@@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
+ Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
- std::vector<std::unique_ptr<Literal>> arg_vector;
+ std::vector<Literal> arg_vector;
arg_vector.push_back(std::move(arg_literal));
arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
ShapeIndex index;
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(165.0,
- LiteralSlice(*result_literal, index));
+ LiteralSlice(result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
DefaultExecutableRunOptions().set_device_ordinal(d));
EXPECT_EQ(d, result.device_ordinal());
LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ ShapedBufferToLiteral(result));
}
}
}
@@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
// As a check to verify that the computation ran of the device associated
// with the stream. This is a weak check, but stronger verification is hard.
EXPECT_EQ(d, result.device_ordinal());
- LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
}
}
@@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
+ Literal tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
- LiteralSlice(*tuple_literal, {0}));
+ LiteralSlice(tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
- LiteralSlice(*tuple_literal, {1}));
+ LiteralSlice(tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
@@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
// Array shapes.
- test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
- test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
- test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*LiteralUtil::MakeTuple({}));
+ test_to_device_and_back(LiteralUtil::MakeTuple({}));
// Non-nested tuples.
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(12223.0)}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}));
// Nested tuple.
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()})
- .get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
test_to_device_and_back(
- *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+ LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}),
+ LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
Add(in, constant);
- std::unique_ptr<Literal> result;
+ Literal result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
}));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
// Join the thread.
thread.reset();
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
@@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
local_client_->TransferFromOutfeedLocal(
shape, local_client_->default_device_ordinal()));
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
// Benchmark that measures the overhead of the LocalClient API when running a
@@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) {
auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
- buffer));
+ ASSERT_IS_OK(
+ transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
const int kWarmups = 2;
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index a8c68fc7fd..f90ef22d2d 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
.ConsumeValueOrDie();
}
-std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+Literal LocalClientTestBase::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
return local_client_->ShapedBufferToLiteral(shaped_buffer)
.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 90095c5d41..4027c7b124 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test {
// Construct and return a literal containing the array represented by
// shaped_buffer.
- std::unique_ptr<Literal> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Execute the given computation on the local client. With and without
// options.
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e195d4..4d327a6fe9 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
@@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
{0});
@@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1});
@@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1, 2});
@@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param2_literal =
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+ Literal param2_literal =
LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
- client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param2_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
- auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index edb592f43e..3f278115e0 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
});
Exp(data);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
{0.36788f, 1.64872f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
@@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
});
Map(&builder, {data}, add_half, {0, 1});
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
{-0.5f, 1.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
@@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
});
Max(lhs, rhs);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
{3.0f, -4.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6));
}
struct TestLinspaceMaxParam {
@@ -200,14 +200,12 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index c5e0b9b097..56aaeb0e68 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -114,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
+ Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
auto actual =
- ExecuteAndTransfer(std::move(hlo_module),
- {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
@@ -178,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect =
- std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
};
@@ -218,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
LiteralUtil::CreateR0<float>(1.0)),
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
LiteralUtil::CreateR0<int32>(4)));
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -247,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -280,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}
const char* const kScalarOps = R"(
@@ -324,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -356,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -389,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
- LiteralUtil::CreateR1<float>({36, 64}),
- LiteralUtil::CreateR1<float>({66, 138})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -422,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -456,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -492,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -530,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest,
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
auto init2 = LiteralUtil::CreateR0<float>(6);
- std::unique_ptr<Literal> result = ExecuteNoHloPasses(
- std::move(module), {param.get(), init1.get(), init2.get()});
+ Literal result =
+ ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -565,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest,
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
LiteralUtil::CreateR3<Eigen::half>(
@@ -576,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest,
{Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)},
{Eigen::half(7), Eigen::half(8)}}})),
- *result));
+ result));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
index 0a0426adcb..f2460822a6 100644
--- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
+++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
@@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
GetTupleElement(result_tuple, 0);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
VLOG(1) << "Transferring trip count to computation";
// Transfer number of iterations to Infeed.
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
// Pick up value from outfeed
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 1);
+ EXPECT_EQ(r.Get<int32>({}), 1);
}
VLOG(1) << "Writing data to infeed";
// Transfer some stuff to Infeed for use inside of loop.
TF_ASSERT_OK(local_client_->TransferToInfeed(
- *LiteralUtil::CreateR1<int32_t>({10, 20})));
+ LiteralUtil::CreateR1<int32_t>({10, 20})));
// Pick up value from outfeed
{
VLOG(1) << "Reading from body outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&xfeed_shape));
- EXPECT_EQ(r->Get<int32>({0}), 11);
- EXPECT_EQ(r->Get<int32>({1}), 21);
+ EXPECT_EQ(r.Get<int32>({0}), 11);
+ EXPECT_EQ(r.Get<int32>({1}), 21);
}
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 0);
+ EXPECT_EQ(r.Get<int32>({}), 0);
}
// Joins the thread
thread.reset();
- EXPECT_EQ(comp_result->Get<int32>({}), 0);
+ EXPECT_EQ(comp_result.Get<int32>({}), 0);
}
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
@@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
}));
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&result_shape));
- EXPECT_EQ(r->Get<bool>({}), true);
+ EXPECT_EQ(r.Get<bool>({}), true);
// Join the thread
thread.reset();
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index cbeddffacf..6e98167739 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
@@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
@@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
@@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
Pad(AddParam(input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
@@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
- padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(3.14f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) {
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
- padding_config);
+ Pad(reduce, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index f6c762e7a4..dcb4c11c3c 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR0<float>(3.14159f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
@@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
@@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
@@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ Literal param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0,
ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
@@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
@@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
@@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
// Use both parameters
//
@@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ Literal literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
@@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ Parameter(&builder, 1, literal1.shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 =
- LiteralUtil::CreateR1<float>({10, 20, 30});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, literal1.shape(), "param2");
// This add is unused.
Add(param1, param2);
@@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ Literal literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
- client_->TransferToServer(*literal).ConsumeValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ client_->TransferToServer(literal).ConsumeValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
+ Literal literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
sum_handle = Add(sum_handle, param);
}
@@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
- parameter_shapes.push_back(literal->shape());
+ parameter_shapes.push_back(literal.shape());
}
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
+ Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
+ std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
XlaOp bool_param =
- Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
+ Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
params.push_back(bool_param);
- parameter_shapes.push_back(bool_literal->shape());
+ parameter_shapes.push_back(bool_literal.shape());
auto init = Tuple(&builder, params);
@@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest,
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR1<float>({4, 5, 6}),
}))
.ConsumeValueOrDie();
@@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ Literal literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});
- const Shape original = literal->shape();
+ const Shape original = literal.shape();
{
// Reverse the layout present in original, and make that the layout of the
// literal.
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
original.layout().minor_to_major().begin(),
original.layout().minor_to_major().end());
std::reverse(original_layout.begin(), original_layout.end());
- *literal->mutable_shape_do_not_use()->mutable_layout() =
+ *literal.mutable_shape_do_not_use()->mutable_layout() =
LayoutUtil::MakeLayout(original_layout);
- ASSERT_EQ(2, literal->Get<float>({0, 1}));
+ ASSERT_EQ(2, literal.Get<float>({0, 1}));
}
// Use the original shape in building the computation.
XlaBuilder builder(TestName());
@@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
// Check that we got the off-diagonal value that we expected.
Array2D<float> expected(1, 1);
expected(0, 0) = 2;
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 5f322b768d..8f2c26f0ee 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -37,8 +37,7 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- std::unique_ptr<Literal> UniformTest(T a, T b, absl::Span<const int64> dims,
- int64 seed = 42);
+ Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
// of the given range size. `expected_count` is the number of times each
@@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase {
};
template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
- absl::Span<const int64> dims,
- int64 seed) {
+Literal PrngTest::UniformTest(T a, T b, absl::Span<const int64> dims,
+ int64 seed) {
XlaBuilder builder(TestName());
RngUniform(
ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -60,8 +58,8 @@ std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
SetSeed(seed);
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
- EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
- actual->EachCell<T>([=](absl::Span<const int64>, T value) {
+ EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions()));
+ actual.EachCell<T>([=](absl::Span<const int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
@@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
constexpr int64 count = 100;
for (int64 seed = 0; seed < count; ++seed) {
auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
- result->Literal::EachCell<bfloat16>(
- [&](absl::Span<const int64>, bfloat16 value) {
- int64 index = static_cast<int64>((value - low) / interval);
- counts[index]++;
- });
+ result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
+ int64 index = static_cast<int64>((value - low) / interval);
+ counts[index]++;
+ });
}
// Each bucket should have similar amount of counts. That is, not more than
// 10% of total counts. This mostly tests that we don't fall into a 1:2:2
@@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
- actual->EachCell<int32>(
+ actual.EachCell<int32>(
[&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
int64 sum = 0;
for (int32 i = 0; i < range_size; ++i) {
@@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
};
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
- client_->TransferToServer(*param0_literal));
+ client_->TransferToServer(param0_literal));
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto fn = build_sum_rng(builder);
Map(&builder, {param0}, fn, {0});
@@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
computation,
/*arguments=*/{param0_data.get()}, &execution_options));
- EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
- ShapeUtil::ElementsIn(param0_literal->shape()));
- for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
- EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
- EXPECT_LT(actual->data<float>()[i],
- param0_literal->data<float>()[i] + 1.0f);
+ EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()),
+ ShapeUtil::ElementsIn(param0_literal.shape()));
+ for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) {
+ EXPECT_GE(actual.data<float>()[i], param0_literal.data<float>()[i]);
+ EXPECT_LT(actual.data<float>()[i], param0_literal.data<float>()[i] + 1.0f);
}
}
@@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
ExecutionOptions execution_options2 = execution_options_;
execution_options2.set_seed(65);
- std::unique_ptr<Literal> result1;
+ Literal result1;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
&execution_options1));
}
- std::unique_ptr<Literal> result2;
- std::unique_ptr<Literal> result3;
+ Literal result2;
+ Literal result3;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options1));
}
- std::unique_ptr<Literal> result4;
- std::unique_ptr<Literal> result5;
- std::unique_ptr<Literal> result6;
+ Literal result4;
+ Literal result5;
+ Literal result6;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_));
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index 9af9ea4a22..c9096fb29b 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ Literal reduce_input = LiteralUtil::CreateR4<float>(
{{ /*i0=0*/
{/*i1=0*/
{-0.246092796, -0.179497838, -0.161181688},
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 0916a07f4f..26e2bfde5c 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR1<float>({input_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
ReducePrecision(a, exponent_bits, mantissa_bits);
@@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// Abs doesn't affect resolution.
auto abs = Abs(a);
@@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 57f7fed61f..83997cdac2 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase {
}, 4);
// clang-format on
CHECK(ShapeUtil::Equal(
- literal_3d_->shape(),
+ literal_3d_.shape(),
ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
- << literal_3d_->shape().ShortDebugString();
+ << literal_3d_.shape().ShortDebugString();
}
// Runs an R1 => R0 reduction test with the given number of elements.
@@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase {
input_data[i] *= -1;
}
}
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(AsSlice(input_data));
+ Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (float item : input_data) {
@@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase {
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+ Literal input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
bool expected = and_reduce;
for (bool item : input_data) {
@@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::array<bool, cols> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
@@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
// NativeT can be bool, and std::vector<bool> does not convert to
// Span.
@@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase {
reference_reduction_function_for_uints, unsigned_int_identity);
}
- std::unique_ptr<Literal> literal_2d_;
- std::unique_ptr<Literal> literal_3d_;
+ Literal literal_2d_;
+ Literal literal_3d_;
uint32 seed_ = 0xdeadbeef;
};
@@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
- MakeFakeLiteral(input_shape));
+ TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
- ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4));
+ ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
}
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
@@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3D(input_data);
+ Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 major = 0; major < 2; ++major) {
@@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
input.Each(
@@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
auto input_min = FLT_MAX;
@@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
@@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XlaBuilder builder("reduce_among_y");
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
@@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
@@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
@@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
+ input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
GetParam().reduce_dims);
@@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
auto a = ConstantR0<float>(&builder, 2.0f);
auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal =
- LiteralUtil::CreateR1<float>({1.0f, 4.0f});
+ Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b = Parameter(&builder, 0, b_literal.shape(), "b");
Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
@@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest {
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
- max_fn, {0});
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
+ {0});
ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
}
@@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1<float>(operand);
+ Literal input_literal = LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
- client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
&builder, expected, {input_global_data.get(), input_global_data2.get()},
ErrorSpec(0.0001));
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index a1001296a1..d5de9650f1 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
&builder_);
ReduceWindow(input, init,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
ReduceWindow(input, init_value,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
/*window_dimensions=*/{},
/*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
{}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
{}, ErrorSpec(0.00001));
}
@@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
Min(Add(lhs, rhs),
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
ReduceWindow(
input,
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
{}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
@@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -360,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- auto arg_literal = absl::make_unique<Literal>(shape);
- arg_literal->PopulateWithValue(1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ Literal arg_literal(shape);
+ arg_literal.PopulateWithValue(1.0f);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
@@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
Shape result_shape =
ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
- auto expected = absl::make_unique<Literal>(result_shape);
- expected->PopulateWithValue(27.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ Literal expected(result_shape);
+ expected.PopulateWithValue(27.0f);
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 1;
int stride = 8;
@@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 3;
int stride = 1;
@@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 8;
int stride = 5;
@@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
@@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -614,11 +608,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], param.base_bounds[3]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(4);
@@ -627,7 +620,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
auto reducer = param.reducer;
if (use_bfloat16() && Product(param.window_bounds) > 128) {
@@ -659,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateFromArray(*expected);
+ Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
- input_literal->shape().element_type(),
- AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
- ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+ input_literal.shape().element_type(),
+ AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
+ ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
DefaultErrorSpec(), &expected_shape_with_layout);
}
};
@@ -1008,12 +1000,11 @@ TEST_P(R3ReduceWindowTest, DoIt) {
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
auto reducer = param.reducer;
if (use_bfloat16()) {
- input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal);
+ input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
if (Product(param.window_bounds) > 128) {
// To avoid numerical issues, force the reducer to be kMax for large bf16
// windows.
@@ -1021,9 +1012,9 @@ TEST_P(R3ReduceWindowTest, DoIt) {
}
}
- XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input");
+ XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
@@ -1035,7 +1026,7 @@ TEST_P(R3ReduceWindowTest, DoIt) {
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec());
+ ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(
@@ -1147,12 +1138,11 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(2);
for (int i = 0; i < 2; ++i) {
@@ -1162,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1178,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
@@ -1352,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) {
const float kInitValue = 0.0f;
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ auto input_arg =
+ CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
std::vector<std::pair<int64, int64>> padding(1);
padding[0] = {param.pad_low[0], param.pad_high[0]};
@@ -1365,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1384,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index d891451381..5cf87e565b 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect 4.
- LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(4, literal);
}
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
@@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed,
/*arguments=*/{x_data.get(), y_data.get()},
@@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
.ConsumeValueOrDie();
// Expect 5.
- LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(5, literal);
}
TEST_F(ReplayTest, MapPlusTwoOverR1) {
@@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect result.
- LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, literal);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 17d12715f6..dedc95b5ae 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
auto a = Neg(parameter);
Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) {
Array2D<float> input_array(0, 3);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) {
Array2D<float> input_array(3, 0);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
auto input_literal =
LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 3});
auto expected_literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 3, 0, 0});
auto expected_literal =
LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{24, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
auto expected_literal = LiteralUtil::CreateFromArray(expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{8, 3});
@@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
{35, 36, 37},
{40, 41, 42},
{45, 46, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{8, 3});
@@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
{45, 16, 26},
{36, 46, 17},
{27, 37, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{2, 6, 2});
auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
t2x2x2x3.FillWithYX(*filler2x3);
auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 1, 1) = 7;
auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{2, 4});
auto expected_literal =
LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
Reshape(parameter, dimensions, {});
auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
- ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&b, expected_literal, {input.get()},
zero_error_spec_);
}
}
@@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {}, {});
EXPECT_THAT(
@@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {1}, {});
EXPECT_THAT(ExecuteToString(&b, {}),
@@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
LayoutUtil::MakeLayout({0, 1, 2, 3}));
// clang-format on
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
@@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0});
- std::unique_ptr<Literal> actual =
+ Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralUtil::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(expected);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
{{204, 205, 206, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
{{206, 7, 107, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
Array4D<float> input(2, 1, 1, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
Array4D<float> input(2, 1, 4, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
Array4D<float> input(5, 10, 2, 3);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
/*new_sizes=*/{5, 60});
@@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
*cell;
});
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
input_array.Each(
[&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
/*new_sizes=*/{7, 2, 3, 5});
XlaComputation computation = builder.Build().ConsumeValueOrDie();
@@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
{2, 3, 0, 1});
- std::unique_ptr<Literal> output_literal =
+ Literal output_literal =
client_
->ExecuteAndTransfer(computation, {input_data.get()},
&execution_options)
@@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
- EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
+ auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
+ EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
} else {
- EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
+ EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
}
}
@@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
+ ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
@@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
/*new_sizes=*/{2, 4, 3, 1});
@@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
{{16}, {20}, {24}}}});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()});
+ ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
@@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
@@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
@@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
@@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
@@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
- ->Relayout(input_literal->shape().layout());
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal)
+ .Relayout(input_literal.shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 74ded82ddf..4e55b0d7ac 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) {
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
- auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
+ auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto a = AddParam(*input_literal, &builder);
+ auto a = AddParam(input_literal, &builder);
Rev(a, spec.reversal);
- std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
+ Literal expected = input_literal.Clone();
std::vector<int64> output_indices(spec.input_dims.size());
- expected->EachCell<float>([&](absl::Span<const int64> indices, float) {
+ expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
for (int64 i = 0; i < indices.size(); ++i) {
output_indices[i] = indices[i];
}
- float value = input_literal->Get<float>(indices);
+ float value = input_literal.Get<float>(indices);
for (int64 dim : spec.reversal) {
output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
}
- expected->Set<float>(output_indices, value);
+ expected.Set<float>(output_indices, value);
});
- ComputeAndCompareLiteral(&builder, *expected, {});
+ ComputeAndCompareLiteral(&builder, expected, {});
}
INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index e692b8c5d5..091a5d2cac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -38,7 +38,7 @@ namespace {
class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
protected:
// Sends the literal to the server and retrieves it back.
- std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ Literal RoundTripToServer(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
return client_->Transfer(*data).ConsumeValueOrDie();
@@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
+ Literal actual =
reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0, actual->Get<float>({0}));
- EXPECT_EQ(24.0, actual->Get<float>({1}));
+ EXPECT_EQ(42.0, actual.Get<float>({0}));
+ EXPECT_EQ(24.0, actual.Get<float>({1}));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
@@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index a8193c2eac..cd5a531603 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
void RoundTripTest(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data).ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
+ Literal result = client_->Transfer(*data).ConsumeValueOrDie();
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, result));
}
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+ RoundTripTest(LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+ RoundTripTest(LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(
- *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ RoundTripTest(LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*LiteralUtil::MakeTuple({}));
+ RoundTripTest(LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
- LiteralUtil::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 07460a7e01..1dd937a6d0 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
+ Literal a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
@@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+ Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> c_data =
- client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+ client_->TransferToServer(c_literal).ConsumeValueOrDie();
- XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
- XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
- XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+ XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
+ XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
+ XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
Mul(Mul(a, b), c);
ComputeAndCompareR0<float>(&builder, 5.775f,
@@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(div_computation,
@@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend / divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(rem_computation,
@@ -430,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend % divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
- TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
+ Literal literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 1858dcea61..d20dba028a 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -62,13 +62,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
@@ -92,13 +90,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
@@ -123,13 +120,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
@@ -154,13 +149,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
@@ -185,13 +178,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
@@ -216,13 +208,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
@@ -247,13 +237,12 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
@@ -277,15 +266,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
@@ -309,15 +296,13 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
@@ -341,12 +326,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
@@ -370,13 +354,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
@@ -400,11 +382,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
@@ -429,12 +410,11 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
@@ -458,13 +438,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
@@ -488,13 +468,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NegativeIndex) {
@@ -518,13 +498,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OneScalarIndex) {
@@ -548,12 +528,12 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates =
LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ScalarUpdate) {
@@ -577,10 +557,10 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, EmptyIndices) {
@@ -604,10 +584,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
+ Literal updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index c9a58aefb4..a40c2d7de6 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
XlaBuilder builder(TestName());
auto original = ConstantR4FromArray4D(&builder, values);
Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
- &expected_literal->shape());
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
+ &expected_literal.shape());
}
struct R1Spec {
@@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase,
auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = Parameter(&builder, 0, literal->shape(), "p0");
+ auto original = Parameter(&builder, 0, literal.shape(), "p0");
Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
@@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase,
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
}
};
@@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) {
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = Parameter(&builder, 0, literal->shape(), "p0");
+ auto a = Parameter(&builder, 0, literal.shape(), "p0");
Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
@@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase,
XlaBuilder builder(TestName());
auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
- auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
+ auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 3ae31191a0..5155f0c652 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
// array. This is uniqueness is best-effort only. Some types (half and bfloat16)
// are not supported and uniqueness cannot be guaranteed if the number of
// elements exceeds the number of different values supported by the type.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> element,
+ Literal element,
MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
@@ -131,60 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<half>(&literal, engine,
no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<float>(&literal, engine,
no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<double>(&literal, engine,
no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(
- literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
+ literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
- std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +286,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+StatusOr<Literal> CreateLiteralForConstrainedUses(
const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
@@ -358,9 +351,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
- return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type());
case ConstantType::kOne:
- return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type());
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
@@ -374,34 +367,33 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param,
- std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+ const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random) {
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeArguments(module, engine.get());
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
arguments[i] =
MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index a260271b1b..b3c8a73905 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -57,8 +57,8 @@ class PseudorandomGenerator {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. See below for documentation of pseudo_random.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random = true);
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
+ bool pseudo_random = true);
// Generates a vector of arguments containing fake data. The number, shape and
// layout of the arguments is appropriate for given HLO module.
@@ -84,14 +84,14 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
// TODO(b/79942829): Make interesting argument generation fast enough that using
// pseudo_random does not save any noticeable amount of time so that the
// parameter can be removed.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random = true);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random = true);
// Overload which accepts a random number generator. This enables generation of
// different random values with sequential calls to MakeFakeArguments by reusing
// the same generator.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine);
// Check that a given module satisfies various constraints before trying to
// execute it.
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 322c8ef090..181e5cbe29 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
@@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index c7eb9e2dbe..b34fd0f2e8 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
@@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(true);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(42, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(42, result.Get<int32>({}));
}
{
@@ -204,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(false);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(7, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(7, result.Get<int32>({}));
}
}
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 125513ddfd..d6641d257a 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase {
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR0<uint32>(42);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
+ LiteralTestUtil::ExpectR0Equal<uint32>(42, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
- *result);
+ result);
}
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1<float>(test_vector);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
+ LiteralTestUtil::ExpectR1Equal<float>(test_vector, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1U8(test_string);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_EQ(result->GetR1U8AsString(), test_string);
+ EXPECT_EQ(result.GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different than the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
- LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
+ LayoutUtil::Equal(result.shape().layout(), literal.shape().layout()));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTuple({});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
+ Literal literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<complex64>(
- {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
- .get(),
- LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
@@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
// supported.
auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
- std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(456.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
-
- auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
- auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+ Literal literal1 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ Literal literal2 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(456.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());
auto stream1 = stream_;
auto stream2 = stream_->GetOrCreateSubStream();
- std::unique_ptr<Literal> result1, result2;
+ Literal result1, result2;
// Round trip literals through device in multiple streams asynchronously.
for (int i = 0; i < kIterationCount; ++i) {
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
device_buffer1));
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
device_buffer2));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result1,
+ Literal this_result1,
transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result2,
+ Literal this_result2,
transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
result1 = std::move(this_result1);
result2 = std::move(this_result2);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
class TransferDeviceToHostBenchmark : public TransferManagerTest {
@@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
}
tensorflow::testing::StopTiming();
@@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
}
tensorflow::testing::StopTiming();
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index f2b3b49015..619d2a388b 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
@@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
- LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar1),
+ LiteralUtil::CreateR0<float>(constant_scalar2)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
- LiteralUtil::CreateR1<float>({}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
@@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
auto expected = LiteralUtil::MakeTuple({});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>(constant_matrix).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>(constant_matrix),
+ LiteralUtil::CreateR1<float>(constant_vector)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
- LiteralUtil::CreateR0<bool>(!direction).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(direction),
+ LiteralUtil::CreateR0<bool>(!direction)});
- ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+ ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
@@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
- LiteralUtil::CreateR1<float>(vec2).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
@@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) {
auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ LiteralUtil::MakeTuple({&expected_v1, &expected_s});
auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
- auto expected =
- LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::MakeTuple(
- {
- LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
- })
- .get(),
- LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+ }),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
}))
.ConsumeValueOrDie();
@@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
- .get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<complex64>({1, 2}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
- {{10000, 20000}, {30000, 40000}}})
- .get()})
- .get()}))
+ {{10000, 20000}, {30000, 40000}}})})}))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
->TransferToServer(
- *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
auto sum =
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = absl::make_unique<Literal>(sum->shape());
- ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](absl::Span<const int64> indexes) {
- return sum->Get<complex64>(indexes) *
- (indexes[indexes.size() - 1] == 0
- ? complex64(1, 2)
- : complex64(1, -2));
- })
+ Literal prod(sum.shape());
+ ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+ return sum.Get<complex64>(indexes) *
+ (indexes[indexes.size() - 1] == 0
+ ? complex64(1, 2)
+ : complex64(1, -2));
+ })
.ok());
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
- LiteralUtil::CreateR0<complex64>({123, 456}).get()});
- ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+ LiteralUtil::CreateR0<complex64>({123, 456})});
+ ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
.ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
- auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+ auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ result));
}
// Disabled on interpreter due to lack of outfeed.
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest,
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
TF_EXPECT_OK(Execute(std::move(module),
- {param0.get(), param1.get(), param1.get(),
- param0.get(), param4.get()})
+ {&param0, &param1, &param1, &param0, &param4})
.status());
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = Literal::CreateFromShape(expected->shape());
+ auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), *literal));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+ backend().default_stream_executor(), expected.shape(), literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 8f80a9f3e4..4fbd7f2fb1 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
{-inf<float>(), 0}});
Abs(arg);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper<complex64>() {
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
Sign(arg);
- std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
+ Literal expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper<complex64>() {
auto abs = Abs(arg);
Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
@@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) {
Add(sgnc, ConvertElementType(
Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, SignTestR1) {
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..7abd8651d5 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// have all reached 2.0.
auto expected_data =
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = LiteralUtil::MakeTuple({expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple(
+ {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
- auto expected = LiteralUtil::MakeTuple(
- {expected_counter.get(), expected_predicate.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+ auto expected =
+ LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR0<int32>(7);
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({&expected_element, &expected_element});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
auto add1 = LiteralUtil::CreateR0<int32>(15);
auto add2 = LiteralUtil::CreateR0<int32>(16);
- auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(
- auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
XlaBuilder builder(TestName());
While(condition, body, ConstantR0<int32>(&builder, 0));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
ComputeAndCompareR0<int32>(&builder, 2, {});
}
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7fd42944de..db5a824de0 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 442e66321e..cdde88c135 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -39,8 +39,7 @@ limitations under the License.
namespace xla {
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
- absl::string_view path) {
+StatusOr<Literal> TextLiteralReader::ReadPath(absl::string_view path) {
CHECK(!absl::EndsWith(path, ".gz"))
<< "TextLiteralReader no longer supports reading .gz files";
std::unique_ptr<tensorflow::RandomAccessFile> file;
@@ -57,7 +56,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
: file_(file) {}
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+StatusOr<Literal> TextLiteralReader::ReadAllLines() {
tensorflow::io::RandomAccessInputStream stream(file_.get());
tensorflow::io::BufferedInputStream buf(&stream, 65536);
string shape_string;
@@ -74,9 +73,9 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
ShapeUtil::HumanString(shape));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
- result->PopulateWithValue<float>(fill);
+ result.PopulateWithValue<float>(fill);
std::vector<absl::string_view> pieces;
std::vector<absl::string_view> coordinates;
std::vector<int64> coordinate_values;
@@ -116,7 +115,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
"\"%s\"",
shape.dimensions_size(), coordinate_values.size(), line);
}
- result->Set<float>(coordinate_values, value);
+ result.Set<float>(coordinate_values, value);
}
return std::move(result);
}
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index b265640802..c40b43279f 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -41,7 +41,7 @@ class TextLiteralReader {
public:
// See class comment -- reads a file in its entirety (there must be only one
// literal in the text file path provided).
- static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path);
+ static StatusOr<Literal> ReadPath(absl::string_view path);
private:
// Ownership of file is transferred.
@@ -49,7 +49,7 @@ class TextLiteralReader {
// Parses a shape string on the first line, followed by lines of values to the
// end of the file.
- StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+ StatusOr<Literal> ReadAllLines();
// Owns the file being read
std::unique_ptr<tensorflow::RandomAccessFile> file_;
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 92f9b4f9f0..1fab4e3a08 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) {
tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
.ok());
- std::unique_ptr<Literal> literal =
- TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+ Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
- EXPECT_EQ(42.5, literal->Get<float>({0, 0, 0}));
- EXPECT_EQ(43.5, literal->Get<float>({0, 0, 1}));
- EXPECT_EQ(44.5, literal->Get<float>({0, 0, 2}));
- EXPECT_EQ(45.5, literal->Get<float>({0, 1, 0}));
- EXPECT_EQ(46.5, literal->Get<float>({0, 1, 1}));
- EXPECT_EQ(47.5, literal->Get<float>({0, 1, 2}));
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape()));
+ EXPECT_EQ(42.5, literal.Get<float>({0, 0, 0}));
+ EXPECT_EQ(43.5, literal.Get<float>({0, 0, 1}));
+ EXPECT_EQ(44.5, literal.Get<float>({0, 0, 2}));
+ EXPECT_EQ(45.5, literal.Get<float>({0, 1, 0}));
+ EXPECT_EQ(46.5, literal.Get<float>({0, 1, 1}));
+ EXPECT_EQ(47.5, literal.Get<float>({0, 1, 2}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 4ea02faffc..5cbaf2fcc1 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) {
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
- ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+ ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path));
string contents;
TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
&contents));
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index ba814af476..0c41f227b3 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -121,11 +121,10 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
}
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
- Literal::CreateFromProto(proto));
+ TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer data,
- client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
scoped_shaped_buffer_arguments.push_back(std::move(data));
}
for (const auto& argument : scoped_shaped_buffer_arguments) {
@@ -161,12 +160,12 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// --generate_fake_infeed is passed and there exists an infeed operation in
// the HloSnapshot.
absl::optional<tensorflow::thread::ThreadPool> pool;
- std::unique_ptr<Literal> data;
+ Literal data;
if (provide_infeed) {
data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
}
auto transfer_infeed = [&data, client]() {
- TF_CHECK_OK(client->TransferToInfeed(*data));
+ TF_CHECK_OK(client->TransferToInfeed(data));
};
if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
@@ -214,9 +213,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< "s: " << module.hlo().hlo_module().name();
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ TF_ASSIGN_OR_RETURN(Literal result_literal,
client->ShapedBufferToLiteral(*result));
- return std::move(*result_literal);
+ return result_literal;
}
StatusOr<HloSnapshot> ParseInputFile(const string& filename,
@@ -305,11 +304,11 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
result.ToString().c_str());
auto& snapshot = snapshots[i];
if (snapshot.has_result()) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
- literal->ToString().c_str());
+ literal.ToString().c_str());
}
}
}
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 478c9663a7..54b06558ad 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -49,7 +49,7 @@ class XRTStateHelpers {
// TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
// OpKernel::Compute method.
static Status MakeLiteral(const xla::LiteralProto& proto,
- std::unique_ptr<xla::Literal>* literal) {
+ xla::Literal* literal) {
TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
return Status::OK();
}
@@ -173,7 +173,7 @@ class XRTAllocateOp : public OpKernel {
errors::InvalidArgument(
"Unable to parse allocation input to XLAAllocation"));
- std::unique_ptr<xla::Literal> literal;
+ xla::Literal literal;
OP_REQUIRES_OK(
ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
@@ -189,7 +189,7 @@ class XRTAllocateOp : public OpKernel {
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
- *literal, device_ref.backend(),
+ literal, device_ref.backend(),
device_ref.device_ordinal(), &allocation));
// Intern takes ownership of our reference to allocation.
@@ -381,11 +381,11 @@ class XRTReadLiteralOp : public OpKernel {
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
ctx, allocation->device_ordinal(), &device_ref));
- std::unique_ptr<xla::Literal> literal;
+ xla::Literal literal;
OP_REQUIRES_OK(
ctx, allocation->ToLiteral(device_ref.backend(),
device_ref.device_ordinal(), &literal));
- xla::LiteralProto literal_proto = literal->ToProto();
+ xla::LiteralProto literal_proto = literal.ToProto();
Tensor output(DT_STRING, TensorShape({}));
literal_proto.SerializeToString(&output.scalar<string>()());
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 5b8516bf1d..2952feb16a 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -52,44 +52,44 @@ string DeviceFromFlag() {
xla::LiteralProto TwoElementTuple() {
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
- return tuple->ToProto();
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ return tuple.ToProto();
}
xla::LiteralProto ScalarLiteral() {
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
- return scalar->ToProto();
+ return scalar.ToProto();
}
xla::LiteralProto NestedTuple() {
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
- auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
- return nested->ToProto();
+ auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
+ return nested.ToProto();
}
xla::LiteralProto MakeTuple0() {
auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
- auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
- auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()});
- auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()});
- return nested1->ToProto();
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
+ auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
+ return nested1.ToProto();
}
-xla::LiteralProto FloatVector(gtl::ArraySlice<float> v) {
+xla::LiteralProto FloatVector(absl::Span<const float> v) {
auto array = xla::LiteralUtil::CreateR1<float>(v);
- return array->ToProto();
+ return array.ToProto();
}
bool CompareLiteralProtos(const xla::LiteralProto& a,
const xla::LiteralProto& b) {
auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
- bool equal = *l_a == *l_b;
+ bool equal = l_a == l_b;
if (!equal) {
LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
<< " != " << b.DebugString();
@@ -100,7 +100,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
bool CompareLiteralToLiteralProto(const xla::Literal& a,
const xla::LiteralProto& b) {
auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
- bool equal = a == *l_b;
+ bool equal = a == l_b;
if (!equal) {
LOG(INFO) << "Literal and LiteralProto don't match "
<< a.ToProto().DebugString() << " != " << b.DebugString();
@@ -211,7 +211,7 @@ TEST(RawApiTest, SubBuffer) {
TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
- auto base_elements = base_literal->DecomposeTuple();
+ auto base_elements = base_literal.DecomposeTuple();
auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
xla::LiteralProto response_0;
EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
@@ -343,7 +343,7 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
- EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
@@ -392,8 +392,8 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
- auto expected = xla::LiteralUtil::MakeTuple({sum.get()});
- EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+ auto expected = xla::LiteralUtil::MakeTuple({&sum});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
} // namespace
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 2c3b07da58..d05a1e7dcb 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -174,7 +174,7 @@ XRTTupleAllocation::~XRTTupleAllocation() {
}
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
- std::unique_ptr<xla::Literal>* literal) {
+ xla::Literal* literal) {
auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 42705688dd..73b5584e38 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -135,7 +135,7 @@ class XRTTupleAllocation : public ResourceBase {
// Copies the allocation from device to host and returns it in literal.
Status ToLiteral(xla::Backend* backend, int device_ordinal,
- std::unique_ptr<xla::Literal>* literal);
+ xla::Literal* literal);
// True if none of the buffers in the allocation are aliased by any other live
// handle.