Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--  tensorflow/compiler/xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/client/client.cc | 12
-rw-r--r--  tensorflow/compiler/xla/client/client.h | 10
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.cc | 4
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc | 20
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h | 10
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc | 46
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h | 45
-rw-r--r--  tensorflow/compiler/xla/literal.cc | 149
-rw-r--r--  tensorflow/compiler/xla/literal.h | 58
-rw-r--r--  tensorflow/compiler/xla/literal_test.cc | 913
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 273
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 228
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.cc | 15
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.h | 3
-rw-r--r--  tensorflow/compiler/xla/protobuf_util.cc | 29
-rw-r--r--  tensorflow/compiler/xla/protobuf_util.h | 4
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc | 20
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h | 8
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i | 18
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc | 7
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.h | 2
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 75
-rw-r--r--  tensorflow/compiler/xla/reference_util.h | 50
-rw-r--r--  tensorflow/compiler/xla/reference_util_test.cc | 50
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_client_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 74
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 323
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/call_graph_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/sample_harness.cc | 30
-rw-r--r--  tensorflow/compiler/xla/service/cpu/shape_partition_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc | 35
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 53
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 129
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 167
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 81
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h | 44
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 38
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 61
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils_test.cc | 68
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 253
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 503
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 203
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling.cc) | 20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.h (renamed from tensorflow/compiler/xla/service/hlo_scheduling.h) | 38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling_test.cc) | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc | 53
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_dce.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_dce_test.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group.cc | 91
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group.h | 81
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_test.cc | 142
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc | 95
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering_test.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 54
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 88
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h | 83
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 79
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.cc | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/inliner_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 260
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.h | 34
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 107
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 49
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/source_map_util.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/tuple_simplifier_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.cc | 19
-rw-r--r--  tensorflow/compiler/xla/shape_tree.h | 9
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 13
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 4
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 256
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc | 128
-rw-r--r--  tensorflow/compiler/xla/tests/bfloat16_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_simple_test.cc | 89
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/call_test.cc | 19
-rw-r--r--  tensorflow/compiler/xla/tests/check_execution_arity_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 71
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 101
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc | 29
-rw-r--r--  tensorflow/compiler/xla/tests/compilation_cache_test.cc | 19
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/tests/conditional_test.cc | 64
-rw-r--r--  tensorflow/compiler/xla/tests/constants_test.cc | 25
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc | 40
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_test.cc | 148
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_variants_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/tests/copy_test.cc | 60
-rw-r--r--  tensorflow/compiler/xla/tests/cross_replica_sum_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc | 41
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 69
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 117
-rw-r--r--  tensorflow/compiler/xla/tests/execution_profile_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc | 130
-rw-r--r--  tensorflow/compiler/xla/tests/gather_operation_test.cc | 161
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc | 23
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h | 12
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.h | 30
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util_test.cc | 43
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_allocation_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_execute_test.cc | 253
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.h | 3
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc | 150
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 87
-rw-r--r--  tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc | 30
-rw-r--r--  tensorflow/compiler/xla/tests/pad_test.cc | 46
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc | 149
-rw-r--r--  tensorflow/compiler/xla/tests/prng_test.cc | 62
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_hlo_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_precision_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc | 123
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 192
-rw-r--r--  tensorflow/compiler/xla/tests/replay_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc | 308
-rw-r--r--  tensorflow/compiler/xla/tests/reverse_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc | 42
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_transfer_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc | 38
-rw-r--r--  tensorflow/compiler/xla/tests/scatter_test.cc | 172
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 74
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h | 12
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/tests/token_hlo_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/tests/transfer_manager_test.cc | 204
-rw-r--r--  tensorflow/compiler/xla/tests/tuple_test.cc | 152
-rw-r--r--  tensorflow/compiler/xla/tests/unary_op_test.cc | 18
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.cc | 11
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.h | 4
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader_test.cc | 17
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tools/replay_computation.cc | 17
-rw-r--r--  tensorflow/compiler/xla/tools/show_literal.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tools/show_text_literal.cc | 16
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto | 3
211 files changed, 5585 insertions, 5235 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 76e36f3c46..ef70c1f8ac 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -193,6 +193,7 @@ cc_library(
":types",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/synchronization",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 8818f81312..5dde5b432f 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
Client::~Client() = default;
-StatusOr<std::unique_ptr<Literal>> Client::Transfer(
- const GlobalData& data, const Shape* shape_with_layout) {
+StatusOr<Literal> Client::Transfer(const GlobalData& data,
+ const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
@@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
@@ -162,7 +162,7 @@ Status Client::ResetDevice() {
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+StatusOr<Literal> Client::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
@@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
- const XlaComputation& computation, const Layout* output_layout) const {
+StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+ const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 7960b07868..6f4d33c469 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -96,8 +96,8 @@ class Client {
//
// If shape_with_layout is not nullptr, it points to a shape whose layout will
// be the layout of the returned literal.
- StatusOr<std::unique_ptr<Literal>> Transfer(
- const GlobalData& data, const Shape* shape_with_layout = nullptr);
+ StatusOr<Literal> Transfer(const GlobalData& data,
+ const Shape* shape_with_layout = nullptr);
// Transfer the given literal to the server. This allocates memory on the
// device and copies the literal's contents over. Returns a global data handle
@@ -122,7 +122,7 @@ class Client {
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ StatusOr<Literal> TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
@@ -132,7 +132,7 @@ class Client {
// Executes the computation with the given arguments and transfers the result
// to the client as a literal. Parameters are defined the same as for
// Execute() and Transfer().
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
@@ -153,7 +153,7 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ StatusOr<Literal> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;
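The net effect on callers of the Client API: Transfer, TransferFromOutfeed, ExecuteAndTransfer, and ComputeConstant now hand back a Literal by value instead of a std::unique_ptr<Literal>. A minimal caller-side sketch (hypothetical helper; it assumes only the signatures shown above):

Status DumpResult(Client* client, const GlobalData& data) {
  // Transfer now yields a Literal directly; nothing to dereference or own.
  TF_ASSIGN_OR_RETURN(Literal literal, client->Transfer(data));
  LOG(INFO) << literal.ToString();
  return Status::OK();
}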
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 6861521acc..25cc37edc4 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (DataSizeOfShape(shape) < (1LL << 20)) {
- StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+ StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
@@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
- return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+ return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.
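The hunk above preserves a size-based strategy: literals under 1 MiB are generated host-side and transferred, while larger (or unsupported) shapes are generated on-device. A condensed sketch of that pattern, simplified in one respect: here any host-side failure falls through to the device path, whereas the real code first CHECKs that the error is UNIMPLEMENTED:

std::unique_ptr<GlobalData> MakeFakeDataSketch(const Shape& shape,
                                               Client* client) {
  if (DataSizeOfShape(shape) < (1LL << 20)) {            // < 1 MiB: host-side
    StatusOr<Literal> literal = MakeFakeLiteral(shape);
    if (literal.ok()) {
      return client->TransferToServer(literal.ValueOrDie()).ValueOrDie();
    }
  }
  return MakeFakeDataViaDeviceOrDie(shape, client);      // large: on-device
}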
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 4402ba8762..f96b6c9c26 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*argument));
- *hlo_snapshot->add_arguments() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+ *hlo_snapshot->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*result));
- *hlo_snapshot->mutable_result() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+ *hlo_snapshot->mutable_result() = literal.ToProto();
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream,
backend_->BorrowStream(shaped_buffer.device_ordinal()));
@@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
return std::move(scoped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
shaped_buffer.device_ordinal()));
@@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
literal);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal) {
+StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, shape, literal.get()));
+ executor, shape, &literal));
return std::move(literal);
}
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 56c3a3da02..feb2f8ec9d 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -84,8 +84,7 @@ class LocalExecutable {
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
- StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
@@ -132,8 +131,7 @@ class LocalClient : public Client {
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
- StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
@@ -151,8 +149,8 @@ class LocalClient : public Client {
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal);
+ StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal);
// Returns the device ordinal that corresponds to the given replica number.
//
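Both local_client hunks lean on the same idiom: Literal is movable, so a routine can build one on the stack and move it into the returned StatusOr, which is what lets these signatures drop unique_ptr. A minimal sketch of the idiom (hypothetical function; it uses only Literal::CreateFromShape as declared in this patch):

StatusOr<Literal> ZeroedLiteral(const Shape& shape) {
  Literal literal = Literal::CreateFromShape(shape);  // zero-initialized
  // ... fill literal in place, e.g. via a transfer manager ...
  return std::move(literal);  // moved, not copied, into the StatusOr
}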
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 887b970661..95ff6432a5 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
- TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
+ LookUpInstructionByHandle(root_id));
ProgramShape program_shape;
- *program_shape.mutable_result() = instructions_[root_id].shape();
+ *program_shape.mutable_result() = root_proto->shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
return;
}
- CHECK(op_handle < instructions_.size() && op_handle >= 0);
-
- const HloInstructionProto& instr = instructions_[op_handle];
+ const HloInstructionProto& instr =
+ *(LookUpInstructionByHandle(op_handle).ValueOrDie());
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
switch (opcode) {
default:
@@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
// Clear data held by this builder.
this->instructions_.clear();
+ this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
@@ -738,7 +739,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
@@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
- // also a valid denpendency order). The related ops will be added to the
+ // also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
@@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
worklist.push(root->id());
related_ops.insert(root->id());
while (!worklist.empty()) {
- int64 node = worklist.front();
+ int64 handle = worklist.front();
worklist.pop();
- for (int64 id : instructions_[node].operand_ids()) {
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+ LookUpInstructionByHandle(handle));
+ for (int64 id : instr_proto->operand_ids()) {
if (related_ops.insert(id).second) {
worklist.push(id);
}
}
- for (int64 called_id : instructions_[node].called_computation_ids()) {
+ for (int64 called_id : instr_proto->called_computation_ids()) {
related_calls.insert(called_id);
}
}
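The loop above is a plain worklist computation of the root's transitive operand closure; because related_ops is a std::set keyed on instruction id, iterating it afterwards visits ids in ascending order, which is a valid dependency order. The same traversal in isolation (a sketch; get_operands is a hypothetical stand-in for LookUpInstructionByHandle plus operand_ids()):

std::set<int64> ReachableIds(
    int64 root, const std::function<std::vector<int64>(int64)>& get_operands) {
  std::set<int64> related = {root};  // ascending ids == dependency order
  std::queue<int64> worklist;
  worklist.push(root);
  while (!worklist.empty()) {
    int64 handle = worklist.front();
    worklist.pop();
    for (int64 id : get_operands(handle)) {
      if (related.insert(id).second) worklist.push(id);  // newly discovered
    }
  }
  return related;
}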
@@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// Add related ops to the computation.
for (int64 id : related_ops) {
auto* instr = entry.add_instructions();
- *instr = instructions_[id];
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
+ LookUpInstructionByHandle(id));
+ *instr = *instr_src;
// Ensures that the instruction names are unique among the graph.
const string& new_name =
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@@ -2415,11 +2420,11 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
- const int64 handle = instructions_.size();
+ const int64 handle = GetUniqueId();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode()));
+ instr.set_name(instr.opcode());
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
@@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
*instr.mutable_sharding() = *sharding_;
}
- instructions_.push_back(instr);
+ handle_to_index_[handle] = instructions_.size();
+ instructions_.push_back(std::move(instr));
XlaOp op(handle, this);
return op;
@@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
op.handle(), op.builder_->name(), this->name());
}
- if (op.handle() >= instructions_.size() || op.handle() < 0) {
- return InvalidArgument("no XlaOp value %d", op.handle());
+ return LookUpInstructionByHandle(op.handle());
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
+ int64 handle) const {
+ auto it = handle_to_index_.find(handle);
+ if (it == handle_to_index_.end()) {
+ return InvalidArgument("No XlaOp with handle %d", handle);
}
- return &instructions_[op.handle()];
+ return &instructions_[it->second];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
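To summarize the xla_builder.cc change: an XlaOp handle is now a globally unique id rather than an index into instructions_, and handle_to_index_ bridges the two. Condensed from the AddInstruction hunk above (not the full method):

int64 handle = GetUniqueId();                     // unique id, not an index
instr.set_id(handle);
handle_to_index_[handle] = instructions_.size();  // map handle -> vector slot
instructions_.push_back(std::move(instr));        // slot now holds the proto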
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 58e8f4e7fa..d0c59fa6f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+ int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
+ // A map from XlaOp::Handle to the index in the instructions_ vector where the
+ // instruction is held.
+ tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
@@ -2112,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2129,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2189,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2207,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2221,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2236,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2253,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3f7635bd40..5035f41988 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = absl::make_unique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+ Literal literal(shape);
+ literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
@@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+ const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = absl::make_unique<Literal>(proto.shape());
+ Literal literal(proto.shape());
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = absl::make_unique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
+ Literal result(new_shape);
+ TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
- result->shape(),
+ result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
+ TF_CHECK_OK(result.CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
@@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
- char* dest_data = static_cast<char*>(result->untyped_data());
+ char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
- std::unique_ptr<Literal> output;
+ Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
- output = CloneToUnique();
+ output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape_do_not_use() =
+ *output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
- ShapeUtil::HumanString(output->shape()));
+ ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- absl::Span<const int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = absl::make_unique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ Literal new_literal(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
+Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
- auto result_literal = absl::make_unique<Literal>(result_shape);
+ Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
+ result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
+ result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
-std::unique_ptr<Literal> LiteralBase::Slice(
- absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
return result;
}
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = absl::make_unique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+ const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
+ Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
+ auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(
+ Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
- absl::Span<complex64> dest_data = result_literal->data<complex64>();
+ absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
+ return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@@ -1362,17 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
@@ -1381,11 +1361,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
+ elements.push_back(std::move(new_element));
}
- auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
- return std::move(converted);
+ return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
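With the round_f32_to_bf16 flag gone, the F32-to-BF16 special case collapses into the ordinary Convert path, completing the TODO(b/69266521) noted in the removed comment. Call-site migration is mechanical (hypothetical literal and shape):

// Before: literal.ConvertToShape(bf16_shape, /*round_f32_to_bf16=*/true);
StatusOr<Literal> converted = literal.ConvertToShape(bf16_shape);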
@@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
+ case S8:
+ proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
+ element_count());
+ break;
case U8:
proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
element_count());
@@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
+ case S8: {
+ auto s8_data = data<int8>();
+ TF_RET_CHECK(proto.s8s().size() == s8_data.size());
+ std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
+ } break;
case U8: {
auto u8_data = data<uint8>();
TF_RET_CHECK(proto.u8s().size() == u8_data.size());
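The WriteToProto/CopyFromProto hunks give S8 the same byte-blob treatment U8 already had, so int8 literals now survive a proto round trip. A sketch of that round trip (it assumes the ToProto/CreateFromProto signatures shown in this diff and XLA's equality operator on literals):

Literal original = LiteralUtil::CreateR1<int8>({-1, 0, 1});
LiteralProto proto = original.ToProto();            // serialized via s8s
StatusOr<Literal> restored = Literal::CreateFromProto(proto);
CHECK(restored.ok());
CHECK(original == restored.ValueOrDie());           // bytes preserved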
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index b928cb6374..1e0a2ad0dd 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -217,31 +217,20 @@ class LiteralBase {
// Converts this literal to the given shape. Returns an error if the
// conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
+ // Clones the underlying buffers into a new Literal.
Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -259,24 +248,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
+ Literal Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+ Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; the
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape, absl::Span<const int64> dimensions) const;
+ StatusOr<Literal> Broadcast(const Shape& result_shape,
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,7 +273,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
+ Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -293,15 +281,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const;
+ Literal Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
+ Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@@ -312,7 +300,7 @@ class LiteralBase {
// initialization, then reinitialization. Consider if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+ static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@@ -539,8 +527,8 @@ class LiteralBase {
private:
template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape, absl::Span<const int64> start_indices) const;
+ Literal SliceInternal(const Shape& result_shape,
+ absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
@@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@@ -1157,7 +1143,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
+ literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {
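Before moving on to the test updates, a similarly hedged sketch of Replicate and the proto round trip under the new signatures (names illustrative; only calls visible in this diff are used):

    // Replicate<T>(times) prepends a dimension and returns a Literal by value;
    // an f32[2] replicated three times yields an f32[3,2].
    Literal row = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
    Literal stacked = row.Replicate<float>(3);

    // Round trip through LiteralProto: ToProto() is unchanged, while the
    // factory now returns StatusOr<Literal> instead of
    // StatusOr<std::unique_ptr<Literal>>.
    xla::LiteralProto proto = stacked.ToProto();
    Literal restored = Literal::CreateFromProto(proto).ConsumeValueOrDie();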
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 1a64594db8..7ad287c897 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test {
Layout layout_r3_dim0minor_;
Layout layout_r4_dim0major_;
Layout layout_r4_dim0minor_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+ Literal literal_r4_2x2x3x3_dim0major_;
+ Literal literal_r4_2x2x3x3_dim0minor_;
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
- EXPECT_EQ("true", true_lit->ToString());
+ EXPECT_EQ("true", true_lit.ToString());
auto false_lit = LiteralUtil::CreateR0<bool>(false);
- EXPECT_EQ("false", false_lit->ToString());
+ EXPECT_EQ("false", false_lit.ToString());
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
- EXPECT_EQ("42", u32_lit->ToString());
+ EXPECT_EQ("42", u32_lit.ToString());
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
- EXPECT_EQ("-999", s32_lit->ToString());
+ EXPECT_EQ("-999", s32_lit.ToString());
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
- EXPECT_EQ("3.14", f32_lit->ToString());
+ EXPECT_EQ("3.14", f32_lit.ToString());
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
- EXPECT_EQ("0.5", f16_lit->ToString());
+ EXPECT_EQ("0.5", f16_lit.ToString());
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
- EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString());
+ EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- EXPECT_EQ("0.5", bf16_lit->ToString());
+ EXPECT_EQ("0.5", bf16_lit.ToString());
// 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- EXPECT_EQ("9", bf16_lit_truncated2->ToString());
+ EXPECT_EQ("9", bf16_lit_truncated2.ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
- EXPECT_EQ("{101}", pred_vec->ToString());
+ EXPECT_EQ("{101}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
@@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) {
{ 3, 4 },
{ 5, 6 }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
@@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) {
{ { 5 },
{ 6 } }
})";
- EXPECT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -171,7 +171,7 @@ f32[2,2] {
{ 3, 4 }
}
))";
- EXPECT_EQ(expected, tuple->ToString());
+ EXPECT_EQ(expected, tuple.ToString());
}
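As the updated test above shows, the MakeTuple migration is mechanical: element literals are now values owned by the caller, and their addresses are passed instead of unique_ptr::get() results. A minimal sketch:

    Literal scalar = LiteralUtil::CreateR0<float>(1.0f);
    Literal matrix = LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
    // MakeTuple takes pointers to caller-owned literals and assembles a tuple
    // literal from them.
    Literal tuple = LiteralUtil::MakeTuple({&scalar, &matrix});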
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
// clang-format on
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
{ 3, 4 },
@@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) {
};
std::vector<int64> expected_values = {8, 9, 7, 10};
- EXPECT_EQ(literal->sparse_indices()->data(),
+ EXPECT_EQ(literal.sparse_indices()->data(),
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
+ EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
{2001, 2002},
}, /*projection_p=*/1, /*projection_z=*/2);
// clang-format on
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
@@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
- EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+ EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
ElementsAre(2, 2, 3, 3));
- string result = literal_r4_2x2x3x3_dim0major_->ToString();
+ string result = literal_r4_2x2x3x3_dim0major_.ToString();
const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
@@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
});
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
- literal->EachCellAsString(
+ literal.EachCellAsString(
[&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) {
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
- EXPECT_EQ(*f32_42, *f32_42);
- EXPECT_EQ(*f32_42, *f32_42_clone);
+ EXPECT_EQ(f32_42, f32_42);
+ EXPECT_EQ(f32_42, f32_42_clone);
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*f32_42, *f32_123);
+ EXPECT_NE(f32_42, f32_123);
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
- EXPECT_NE(*f32_42, *f64_42);
+ EXPECT_NE(f32_42, f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(*matrix, *matrix);
- EXPECT_EQ(*matrix, *matrix_clone);
- EXPECT_NE(*matrix, *matrix_different);
- EXPECT_NE(*matrix, *vector_literal);
- EXPECT_NE(*matrix, *scalar);
- EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(matrix, matrix);
+ EXPECT_EQ(matrix, matrix_clone);
+ EXPECT_NE(matrix, matrix_different);
+ EXPECT_NE(matrix, vector_literal);
+ EXPECT_NE(matrix, scalar);
+ EXPECT_NE(matrix, nil);
EXPECT_EQ(nil, nil);
}
@@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) {
auto token1 = LiteralUtil::CreateToken();
auto scalar = LiteralUtil::CreateR0<float>(1.0);
- EXPECT_EQ(*token0, *token1);
- EXPECT_NE(*token0, *scalar);
+ EXPECT_EQ(token0, token1);
+ EXPECT_NE(token0, scalar);
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
- *LiteralUtil::MakeTuple({token0.get()}));
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+ LiteralUtil::MakeTuple({&token0}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&token1, &scalar}));
+ EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&scalar, &token1}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
- colmajor->Set<float>({0, 0}, 1.0);
- colmajor->Set<float>({0, 1}, 2.0);
- colmajor->Set<float>({1, 0}, 3.0);
- colmajor->Set<float>({1, 1}, 4.0);
+ Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ colmajor.Set<float>({0, 0}, 1.0);
+ colmajor.Set<float>({0, 1}, 2.0);
+ colmajor.Set<float>({1, 0}, 3.0);
+ colmajor.Set<float>({1, 1}, 4.0);
- auto rowmajor = absl::make_unique<Literal>(
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
- rowmajor->Set<float>({0, 0}, 1.0);
- rowmajor->Set<float>({0, 1}, 2.0);
- rowmajor->Set<float>({1, 0}, 3.0);
- rowmajor->Set<float>({1, 1}, 4.0);
+ Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ rowmajor.Set<float>({0, 0}, 1.0);
+ rowmajor.Set<float>({0, 1}, 2.0);
+ rowmajor.Set<float>({1, 0}, 3.0);
+ rowmajor.Set<float>({1, 1}, 4.0);
- EXPECT_EQ(*rowmajor, *colmajor);
+ EXPECT_EQ(rowmajor, colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
- auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_EQ(*tuple1, *tuple2);
+ auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+ EXPECT_EQ(tuple1, tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_NE(*tuple1, *reversed_tuple);
+ auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+ EXPECT_NE(tuple1, reversed_tuple);
// Tuple with a different value.
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
- auto different_tuple =
- LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_NE(*tuple1, *different_tuple);
+ auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+ EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
@@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) {
// tuple, the other is a clone of the element in the original tuple.
auto vector_clone =
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- EXPECT_EQ(*vector, *vector_clone);
+ EXPECT_EQ(vector, vector_clone);
auto vector_reversed =
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
- EXPECT_NE(*vector, *vector_reversed);
+ EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+ auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
// Tuples should always return false for IsAll.
- EXPECT_FALSE(tuple->IsAll(0));
- EXPECT_FALSE(tuple->IsAll(1));
+ EXPECT_FALSE(tuple.IsAll(0));
+ EXPECT_FALSE(tuple.IsAll(1));
}
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto scalar = LiteralUtil::CreateR0<float>(0.0);
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
- auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_EQ(*tuple, *x);
+ auto x = Literal::CreateFromShape(tuple.shape());
+ EXPECT_EQ(tuple, x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
- EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
- EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
- ->IsAll(-1));
+ .IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
- ->IsAllFloat(.5));
+ .IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllFirst returns true only when every element matches the first.
- EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(
- LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
- EXPECT_TRUE(scalar_zero->IsZero({}));
- EXPECT_FALSE(scalar_one->IsZero({}));
+ EXPECT_TRUE(scalar_zero.IsZero({}));
+ EXPECT_FALSE(scalar_one.IsZero({}));
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
- EXPECT_FALSE(array->IsZero({0, 1}));
- EXPECT_TRUE(array->IsZero({0, 2}));
- EXPECT_TRUE(array->IsZero({1, 1}));
- EXPECT_FALSE(array->IsZero({1, 2}));
+ EXPECT_FALSE(array.IsZero({0, 1}));
+ EXPECT_TRUE(array.IsZero({0, 2}));
+ EXPECT_TRUE(array.IsZero({1, 1}));
+ EXPECT_FALSE(array.IsZero({1, 2}));
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
- EXPECT_TRUE(complex_zero->IsZero({}));
- EXPECT_FALSE(complex_nonzero->IsZero({}));
+ EXPECT_TRUE(complex_zero.IsZero({}));
+ EXPECT_FALSE(complex_nonzero.IsZero({}));
}
template <typename T>
@@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
- auto data01 = data->Relayout(layout01);
- EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_EQ(*data, *data01);
+ auto data01 = data.Relayout(layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+ EXPECT_EQ(data, data01);
- auto data10 = data->Relayout(layout10);
- EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_EQ(*data, *data10);
+ auto data10 = data.Relayout(layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+ EXPECT_EQ(data, data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Transpose(/*permutation=*/{});
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}});
// clang-format on
- auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+ auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
+ reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+ EXPECT_EQ(value, original.Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
}
@@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
- literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+ literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
- literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+ literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->element_count(), 6);
- EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_EQ(mat_dim0minor.element_count(), 6);
+ EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
- auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+ EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->element_count(), 6);
- EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_EQ(mat_dim0major.element_count(), 6);
+ EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+ EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
ElementsAre(1, 4, 2, 5, 3, 6));
}
@@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->element_count(), 12);
+ EXPECT_EQ(lit_dim0minor.element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->data<int32>(),
+ EXPECT_THAT(lit_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
- auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+ auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+ EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->element_count(), 12);
- EXPECT_THAT(lit_dim0major->data<int32>(),
+ EXPECT_EQ(lit_dim0major.element_count(), 12);
+ EXPECT_THAT(lit_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+ auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+ EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = LiteralUtil::CreateR0<int32>(1);
- auto result = input->Slice({}, {});
- EXPECT_EQ(*input, *result);
+ auto result = input.Slice({}, {});
+ EXPECT_EQ(input, result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
- auto result = input->Slice({3}, {4});
+ auto result = input.Slice({3}, {4});
auto expected = LiteralUtil::CreateR1<float>({4.0});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto result = input_3x4->Slice({0, 2}, {2, 4});
+ auto result = input_3x4.Slice({0, 2}, {2, 4});
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_EQ(*input_2x3x2, *result);
+ auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+ EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = LiteralUtil::CreateR1<int64>({77});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
auto expected = LiteralUtil::CreateR0<float>(2.5f);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
output.PopulateWithValue<complex64>({4, 2});
auto expected =
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR0<half>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
auto input = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto output = input->Replicate<uint32>(3);
+ auto output = input.Replicate<uint32>(3);
auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_EQ(*output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
auto init_proc = [&](absl::Span<const int64> indexes) {
- source->Set(indexes, ++seqnr);
+ source.Set(indexes, ++seqnr);
return true;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
init_proc);
auto blank = Literal::CreateFromShape(shape);
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
blank_indexes.begin(), std::plus<int64>());
- auto bval = blank->Get<uint32>(blank_indexes);
- matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+ auto bval = blank.Get<uint32>(blank_indexes);
+ matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
return matched;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
check_proc);
EXPECT_TRUE(matched);
}
@@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = LiteralUtil::CreateR0<uint32>(0);
auto nine = LiteralUtil::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->CopyFrom(*nine));
- EXPECT_EQ(*zero, *nine);
+ TF_EXPECT_OK(zero.CopyFrom(nine));
+ EXPECT_EQ(zero, nine);
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
- EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
- EXPECT_EQ(vect->Get<uint32>({4}), 17);
+ TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+ EXPECT_EQ(zero.Get<uint32>({}), 17);
+ TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+ EXPECT_EQ(vect.Get<uint32>({4}), 17);
}
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
- EXPECT_EQ(*nine, *const_nine);
+ TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+ EXPECT_EQ(nine, const_nine);
}
{
// Copy zero elements to a destination with zero elements.
- const auto empty = Literal::CreateFromShape(empty_r1_shape);
+ auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
- EXPECT_EQ(*empty, *const_empty);
+ TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+ EXPECT_EQ(empty, const_empty);
}
}
@@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
TEST_F(LiteralUtilTest, CopyFromArrays) {
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*scalar_42, *scalar_123);
- TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*scalar_42, *scalar_123);
- EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+ EXPECT_NE(scalar_42, scalar_123);
+ TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(scalar_42, scalar_123);
+ EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
- EXPECT_NE(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
- TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+ EXPECT_NE(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {matrix.get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get()});
+ Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0})};
+ Literal inner_tuple = LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal});
+ Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
// Create a tuple with the same shape as the inner tuple of nested_tuple but
// with different values.
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(-5).get(),
- LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+ Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+ Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+ Literal tuple =
+ LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
// Overwrite the inner tuple element of nested_tuple with the contents of
// 'tuple'.
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{}));
+ TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
- LiteralUtil::CreateR0<int32>(4).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+ LiteralUtil::CreateR0<int32>(4)};
+ Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
// Copy from one element to the other.
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{0}));
+ TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
- Status status = matrix->CopyFrom(*vector);
+ Status status = matrix.CopyFrom(vector);
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Destination subshape incompatible"));
@@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) {
// Verify that the internal data views are consistent and that they
// are in little-endian format.
// TODO - modify if we make the data format machine-endianness dependent
- auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
- Literal* l1 = m1.get();
- const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+ Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+ const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
@@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l2 = m2.get();
- const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+ const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
@@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offset from the linear index just to avoid R0 literals being initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->Populate<uint32>(generator));
+ TF_EXPECT_OK(literal.Populate<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offset from the linear index just to avoid R0 literals being initialized
// with zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+ TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](absl::Span<const int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// clang-format on
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->Convert(U32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
- EXPECT_EQ(*expected, *converted);
+ EXPECT_EQ(expected, converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
- std::unique_ptr<Literal> conv;
+ Literal conv;
- conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u32);
+ conv = s8.Convert(U32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u32);
- conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = s8.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u64);
+ conv = s8.Convert(U64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u64);
- conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s64);
+ conv = s8.Convert(S64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s64);
- conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *pred);
+ conv = s8.Convert(PRED).ConsumeValueOrDie();
+ EXPECT_EQ(conv, pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = bf16.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = bf16.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *int32_pred);
+ conv = pred.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, int32_pred);
- conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f32.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f64.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = s32.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f64.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = s32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = u32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = s32.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- conv = f16->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = f16.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- EXPECT_EQ(s32->Convert(TUPLE).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(S16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(U16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(F32).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(S32).status().code(),
+ EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) {
tensorflow::bit_cast<uint32>(100.f), 0xbeef});
auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->BitcastConvert(F32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
- Status status = literal->BitcastConvert(F64).status();
+ Status status = literal.BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(
absl::StrContains(status.error_message(), "bit widths are different"));
@@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
p.add_preds((i % 2) == (len % 2));
}
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- ASSERT_EQ(len, literal->data<bool>().size());
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal.data<bool>().size());
int i = 0;
- for (bool value : literal->data<bool>()) {
+ for (bool value : literal.data<bool>()) {
EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
@@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h2(2.0f);
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->data<half>().size());
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+ EXPECT_EQ(4, m.data<half>().size());
- LiteralProto p = l->ToProto();
+ LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
@@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- auto r = literal->data<half>();
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ auto r = literal.data<half>();
ASSERT_EQ(4, r.size());
EXPECT_EQ(h1, r[0]);
EXPECT_EQ(h2, r[1]);
@@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+ EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+ EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 1.0f);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
1.0f);
- nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 555.0f);
+ nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
555.0f);
@@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view,
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
}
TEST_F(LiteralUtilTest, LiteralMove) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- Literal literal(std::move(*matrix));
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(matrix));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get(),
- &nil_literal});
-
- EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
- std::vector<Literal> elements = nested_tuple->DecomposeTuple();
- EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+ Literal inner_elements[] = {
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}),
+ };
+ Literal tuple_elements[] = {
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+ LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal}),
+ };
+ Literal nested_tuple = LiteralUtil::MakeTuple(
+ {&tuple_elements[0], &tuple_elements[1], &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+ std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
ASSERT_EQ(elements.size(), 3);
@@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
- elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
- ));
+ elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+ elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+ std::vector<Literal> inner_elements;
+ inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+ inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+ elements.push_back(
+ LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
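DecomposeTuple and MoveIntoTuple are roughly inverse operations: the former moves a tuple's elements out into a vector (leaving the source nil-shaped, as asserted earlier), the latter assembles a tuple by moving a span of Literals in. A round-trip sketch, with `some_tuple` a hypothetical tuple-shaped Literal:

    std::vector<Literal> parts = some_tuple.DecomposeTuple();
    // some_tuple is now nil-shaped; its storage has moved into `parts`.
    Literal rebuilt = Literal::MoveIntoTuple(absl::MakeSpan(parts));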
@@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
Literal literal;
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- literal = std::move(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(matrix);
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
}
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralSlice(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralSlice(matrix);
LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(42.0).get(),
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
- tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
-
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
- 3.0);
- tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ Literal elements[] = {
+ LiteralUtil::CreateR0<float>(42.0),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ };
+ auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+ tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-4.0);
}
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
// Literals constructed using CreateFromShape should be zero initialized.
- std::unique_ptr<Literal> scalar_f32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
- EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
- EXPECT_TRUE(scalar_f32->IsAll(0));
-
- std::unique_ptr<Literal> vector_s32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
- EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
- EXPECT_TRUE(vector_s32->IsAll(0));
-
- std::unique_ptr<Literal> tuple =
- Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
- ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
-
- EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
- EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
- EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
- EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+ Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32.IsAll(0));
+
+ Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32.IsAll(0));
+
+ Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1657,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
@@ -1665,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto matrix_pred =
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
auto tuple = LiteralUtil::MakeTuple(
- {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+ auto nested_tuple =
+ LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
- return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
};
- EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
- EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
- EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
- EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
- EXPECT_EQ(*tuple, to_from_proto(*tuple));
- EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
+ EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
+ EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+ EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+ EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+ EXPECT_EQ(tuple, to_from_proto(tuple));
+ EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
- EXPECT_NE(*one_f32, *two_f32);
- EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+ EXPECT_NE(one_f32, two_f32);
+ EXPECT_NE(one_f32, to_from_proto(two_f32));
}
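The to_from_proto helper above is the whole serialization story after this change: ToProto produces a LiteralProto, and CreateFromProto returns StatusOr<Literal> by value rather than StatusOr<std::unique_ptr<Literal>>. Sketch, with error handling elided as in the test:

    Literal original = LiteralUtil::CreateR1<int8>({-128, 0, 127});
    LiteralProto proto = original.ToProto();
    Literal restored = Literal::CreateFromProto(proto).ValueOrDie();
    CHECK_EQ(original, restored);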
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1802,11 +1788,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
TEST_F(LiteralUtilTest, SortSparseElements) {
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
SparseIndexArray(10, 3), {});
- literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
- literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
- literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
- literal->SortSparseElements();
- EXPECT_EQ(literal->ToString(false),
+ literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal.SortSparseElements();
+ EXPECT_EQ(literal.ToString(false),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -1816,57 +1802,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
EXPECT_EQ(
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
"false");
EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(int64{2}));
EXPECT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(double{2.0}));
EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat(static_cast<float>(half{2.0})));
EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+ Literal literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
- /*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
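The three Broadcast tests differ only in `dimensions`, which names the result dimension each operand dimension maps to: {0} replicates the vector down columns, {1} across rows, and {} broadcasts a scalar everywhere. Restated as a sketch of the first case:

    Literal v = LiteralUtil::CreateR1<int64>({1, 2});
    // v's single dimension becomes result dimension 0, so row i is {v[i], v[i]}.
    TF_ASSERT_OK_AND_ASSIGN(
        Literal b, v.Broadcast(ShapeUtil::MakeShape(S64, {2, 2}),
                               /*dimensions=*/{0}));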
} // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 613449cf10..0cb1ae35f4 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -45,7 +45,7 @@ using absl::StrCat;
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+Literal ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
+ auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
+ TF_CHECK_OK(result.CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
}
}
});
@@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
+/* static */ Literal LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
+/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
+/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
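Both conversion wrappers now hand back a fresh Literal by value; per ConvertType above, subshapes whose element type does not match are copied through unchanged via CopyFrom. Minimal usage sketch:

    Literal bf16 =
        LiteralUtil::CreateR1<bfloat16>({bfloat16{1.0f}, bfloat16{2.0f}});
    Literal f32 = LiteralUtil::ConvertBF16ToF32(bf16);
    CHECK_EQ(f32.Get<float>({0}), 1.0f);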
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
- return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
+/* static */ Literal LiteralUtil::CreateToken() {
+ return Literal(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(0));
+ return LiteralUtil::CreateR0<uint8>(0);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(0));
+ return LiteralUtil::CreateR0<uint32>(0);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(0));
+ return LiteralUtil::CreateR0<uint64>(0);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(0));
+ return LiteralUtil::CreateR0<int8>(0);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(0));
+ return LiteralUtil::CreateR0<int32>(0);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(0));
+ return LiteralUtil::CreateR0<int64>(0);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(0));
+ return LiteralUtil::CreateR0<float>(0);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(0));
+ return LiteralUtil::CreateR0<double>(0);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(0));
+ return LiteralUtil::CreateR0<complex64>(0);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(1));
+ return LiteralUtil::CreateR0<uint8>(1);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(1));
+ return LiteralUtil::CreateR0<uint32>(1);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(1));
+ return LiteralUtil::CreateR0<uint64>(1);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(1));
+ return LiteralUtil::CreateR0<int8>(1);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(1));
+ return LiteralUtil::CreateR0<int32>(1);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(1));
+ return LiteralUtil::CreateR0<int64>(1);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(1));
+ return LiteralUtil::CreateR0<float>(1);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(1));
+ return LiteralUtil::CreateR0<double>(1);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(1));
+ return LiteralUtil::CreateR0<complex64>(1);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- -std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- -std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity());
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}
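All four scalar factories (Zero, One, MinValue, MaxValue) now return Literal directly instead of moving out of a unique_ptr. Illustrative sketch, restricted to types the switches above support:

    Literal zero = LiteralUtil::Zero(F32);      // 0.0f scalar
    Literal one  = LiteralUtil::One(S32);       // 1 scalar
    Literal lo   = LiteralUtil::MinValue(F32);  // -infinity for floats
    Literal hi   = LiteralUtil::MaxValue(U8);   // 255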
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = absl::make_unique<Literal>(
+ Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
- absl::string_view value) {
- auto literal = absl::make_unique<Literal>(
- ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
+ Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
- literal->Set<uint8>({i}, value[i]);
+ literal.Set<uint8>({i}, value[i]);
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
- float from, float to, int64 rows, int64 cols) {
+/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
+ int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
+/* static */ Literal LiteralUtil::ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
@@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
- auto new_literal = absl::make_unique<Literal>(
+ Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
  // solely for converting linear addresses to multi-dimensional addresses when
// writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
+ Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
@@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
+ new_literal.Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
break;
case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
+ new_literal.Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
break;
case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
+ new_literal.Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
break;
case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
+ new_literal.Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
break;
case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
+ new_literal.Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
break;
case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
+ new_literal.Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
break;
case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
+ new_literal.Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
break;
case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
+ new_literal.Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
break;
case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
+ new_literal.Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
- return std::move(
- *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
+ return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
+ return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
+ return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- literal.GetFirstElement<bfloat16>()));
+ return LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>());
case F16:
- return std::move(
- *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
+ return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
- return std::move(
- *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
+ return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
- return std::move(
- *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
+ return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
- return std::move(
- *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
+ return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
+ return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
+ return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(
- literal.GetFirstElement<complex64>()));
+ return LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>());
case F64:
- return std::move(
- *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
+ return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
+ return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
+ return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
+/* static */ Literal LiteralUtil::MakeTuple(
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
+/* static */ Literal LiteralUtil::MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements) {
+/* static */ Literal LiteralUtil::MakeTupleOwned(
+ std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
- element_shapes.push_back(element->shape());
+ element_shapes.push_back(element.shape());
}
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
- literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2d6084a67a..2b181621ed 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -69,36 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR0(NativeT value);
+ static Literal CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
- static std::unique_ptr<Literal> CreateR1(
- const tensorflow::core::Bitmap& values);
+ static Literal CreateR1(absl::Span<const NativeT> values);
+ static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2(
+ static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2WithLayout(
+ static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3(
- std::initializer_list<
- std::initializer_list<std::initializer_list<NativeT>>>
- values);
+ static Literal CreateR3(std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3WithLayout(
+ static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4(
+ static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4WithLayout(
+ static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -139,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
- static std::unique_ptr<Literal> CreateSparse(
- absl::Span<const int64> dimensions, SparseIndexArray indices,
- absl::Span<const NativeT> values, bool sort = true);
+ static Literal CreateSparse(absl::Span<const int64> dimensions,
+ SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -155,130 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
+ static Literal CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
+ static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2D(
- const Array2D<NativeT>& values);
+ static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
+ static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3D(
- const Array3D<NativeT>& values);
+ static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
+ static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4D(
- const Array4D<NativeT>& values);
+ static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
+ static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
  // Creates a new rank-1 literal of U8s from a string.
- static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
+ static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
- static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
- int64 rows, int64 cols);
+ static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+ int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3Projected(
+ static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
+ static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
- static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+ static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
- static std::unique_ptr<Literal> MakeTuple(
- absl::Span<const Literal* const> elements);
+ static Literal MakeTuple(absl::Span<const Literal* const> elements);
- static std::unique_ptr<Literal> MakeTupleFromSlices(
- absl::Span<const LiteralSlice> elements);
+ static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
- // std::vector<std::unique_ptr<Literal>> elements = ...;
+ // std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements);
+ static Literal MakeTupleOwned(std::vector<Literal> elements);
- // This overload lets you pass a braced list of unique_ptr<Literal>s to
+ // This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
- // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+ // Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
- // The arguments to this function must all be unique_ptr<Literal>.
+ // The arguments to this function must all be Literal.
template <typename... Ts>
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::unique_ptr<Ts>... elements) {
- std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
- std::move(elements)...};
- std::vector<std::unique_ptr<Literal>> v;
+ static Literal MakeTupleOwned(Ts... elements) {
+ std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+ std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
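With ownership expressed by value, the variadic overload takes plain `Ts... elements` (each of which must be a Literal, per the comment above) instead of `std::unique_ptr<Ts>...`. A usage sketch matching the documented call shape:

    Literal t = LiteralUtil::MakeTupleOwned(
        LiteralUtil::CreateR0<int32>(42),
        LiteralUtil::CreateR1<double>({23.0, 44.0}));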
// Create a constant token literal. Token types have no value.
- static std::unique_ptr<Literal> CreateToken();
+ static Literal CreateToken();
  // Creates a new Literal object with its values having the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type, absl::Span<const int64> dimensions);
+ static Literal CreateFromDimensions(PrimitiveType primitive_type,
+ absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(
- const LiteralSlice& bf16_literal);
+ static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(
- const LiteralSlice& f32_literal);
+ static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
- static std::unique_ptr<Literal> ReshapeSlice(
- absl::Span<const int64> new_dimensions,
- absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
+ static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major,
+ const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -286,7 +275,7 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator);
@@ -297,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+ T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -307,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+ T stddev);
//
// End of factory methods.
@@ -322,44 +311,43 @@ class LiteralUtil {
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
- literal->Set({}, value);
+ literal.Set({}, value);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- absl::Span<const NativeT> values) {
- auto literal = absl::make_unique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+ Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateR2(values);
+ literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -384,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -422,23 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
+/* static */ Literal LiteralUtil::CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
- indices.max_indices()));
- literal->PopulateSparse(indices, values, sort);
+ Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+ indices.max_indices()));
+ literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -446,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateFromArray(values);
+ literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -514,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -565,33 +549,29 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
- NativeT value) {
- auto literal =
- absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
- literal->PopulateWithValue(value);
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value) {
+ Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ Literal literal(shape);
+ TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
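All three CreateRandomLiteral overloads now yield StatusOr<Literal>. A sketch of the simplest form, which supplies its own std::minstd_rand0 engine internally:

    TF_ASSERT_OK_AND_ASSIGN(
        Literal r, LiteralUtil::CreateRandomLiteral<F32>(
                       ShapeUtil::MakeShape(F32, {2, 3}), /*mean=*/0.0f,
                       /*stddev=*/1.0f));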
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index bddb664149..0f86f9f35e 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -40,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
-StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
- const Shape& shape, const Layout* layout) {
+StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
+ const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -58,14 +57,14 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()));
}
- auto result = absl::make_unique<Literal>(literal_shape);
- result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+ Literal result(literal_shape);
+ result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
- absl::Span<const float> field = result->data<float>();
+ absl::Span<const float> field = result.data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -86,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 98dccaa9a2..d6d2ff1521 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
- StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
- const Layout* layout = nullptr);
+ StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.
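
A hedged sketch of the new calling convention; the file-opening boilerplate and `path` are illustrative, not part of this change, and the fragment assumes an enclosing Status-returning function:

  // Sketch: Read() now yields the Literal by value inside the StatusOr.
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewRandomAccessFile(path, &file));
  xla::PackedLiteralReader reader(file.release());  // reader deletes the file
  TF_ASSIGN_OR_RETURN(
      xla::Literal literal,
      reader.Read(xla::ShapeUtil::MakeShape(xla::F32, {1024})));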
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index 787725e884..b507a2ef79 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
@@ -49,16 +50,40 @@ string SanitizeFileName(const string& file_name) {
return safe_file_name;
}
+std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
+GetDirectoryExpanders() {
+ static auto* mutex = new tensorflow::mutex;
+ static auto* singleton = new std::vector<std::function<string(string)>>;
+ return {mutex, singleton};
+}
+
+// Runs all the directory expanders over x and returns the result.
+string Expand(string x) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ for (const auto& f : *pair.second) {
+ x = f(x);
+ }
+ return x;
+}
+
} // namespace
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
+ string expanded_dir = Expand(directory);
+ TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir));
string safe_file_name = SanitizeFileName(file_name) + ".pb";
- const string path = tensorflow::io::JoinPath(directory, safe_file_name);
+ const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name);
return tensorflow::WriteBinaryProto(env, path, message);
}
+void RegisterDirectoryExpander(const std::function<string(string)>& expander) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ pair.second->push_back(expander);
+}
+
} // namespace protobuf_util
} // namespace xla
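
GetDirectoryExpanders() above uses the intentionally-leaked function-local-static idiom: the mutex and vector are heap-allocated and never destroyed, which sidesteps static-destruction-order hazards for a registry that may still be consulted late in process shutdown. A generic, self-contained sketch of the same idiom outside TensorFlow types:

  #include <mutex>
  #include <string>
  #include <utility>
  #include <vector>

  std::pair<std::mutex*, std::vector<std::string>*> GetRegistry() {
    static auto* mu = new std::mutex;                   // leaked on purpose
    static auto* items = new std::vector<std::string>;  // leaked on purpose
    return {mu, items};
  }

  void Register(const std::string& item) {
    auto reg = GetRegistry();
    std::lock_guard<std::mutex> lock(*reg.first);
    reg.second->push_back(item);
  }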
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 3667621367..f22fc8b849 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
+// Registers a function that may either expand a dirpath or forward the original
+// dirpath along as-is.
+void RegisterDirectoryExpander(const std::function<string(string)>& expander);
+
} // namespace protobuf_util
} // namespace xla
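
A hedged usage sketch; the lambda and its /tmp prefix are hypothetical, not something this change installs:

  // Sketch: reroute every dump directory before DumpProtoToDirectory()
  // creates it.
  xla::protobuf_util::RegisterDirectoryExpander(
      [](std::string dir) { return "/tmp/xla_dumps/" + dir; });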
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd6e20b693..9da5dc0d2d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}
-StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number) {
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- return ToBuffer(client, /*device_ordinal=*/0, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
-StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
+StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
@@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
-StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
+StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
@@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
- std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+ std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- pushed = ToBuffer(client, device_ordinal, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
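
The Relayout call sites above show the wider migration at work: Literal is now a move-only value type, so the relaid-out copy lives on the stack instead of behind a unique_ptr. A minimal sketch under that assumption:

  // Sketch: value-semantics Literal; moves steal the underlying buffers.
  xla::Literal argument = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  xla::Literal relaid = argument.Relayout(
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2}, {0}));
  xla::Literal moved = std::move(relaid);  // cheap, no deep copy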
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 78b3c598b9..1d5dfe5911 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
-StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number);
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -65,7 +65,7 @@ class LocalShapedBuffer {
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
- StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+ StatusOr<Literal> ToLiteral() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@@ -117,7 +117,7 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
- StatusOr<std::unique_ptr<Literal> > Execute(
+ StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape> >& shapes_with_layout);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 450d3fe5af..521490e76c 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
-    std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
-    $result = numpy::PyObjectFromXlaLiteral(*value);
+    Literal value = $1.ConsumeValueOrDie();
+    $result = numpy::PyObjectFromXlaLiteral(value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -346,25 +346,25 @@ tensorflow::ImportNumpy();
// Literal
-%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
+%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
- $1 = literal_status.ValueOrDie().get();
+ $1 = &literal_status.ValueOrDie();
}
-%typemap(out) std::unique_ptr<Literal> {
-  $result = numpy::PyObjectFromXlaLiteral(*$1);
+%typemap(out) Literal {
+  $result = numpy::PyObjectFromXlaLiteral($1);
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
- $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+ $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
+ StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
- temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
+ temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;
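
The typemap relies on ConsumeValueOrDie() now returning the Literal by value, which push_back can move rather than copy. The same pattern outside SWIG, with ParseLiteral() as a hypothetical producer:

  // Sketch: moving literals out of StatusOr into a vector.
  std::vector<xla::Literal> temps;
  xla::StatusOr<xla::Literal> literal_status = ParseLiteral();  // hypothetical
  if (literal_status.ok()) {
    temps.push_back(literal_status.ConsumeValueOrDie());  // moves, no copy
  }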
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index fc6511bef5..b0aa024c74 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}
-StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
- TF_RETURN_IF_ERROR(
- CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
+ TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 8cae175185..40ff2d9ad2 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
-StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 9f1afa2671..ceb5e74db7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
- float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
@@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
const Array4D<float>& source,
float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
@@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -565,7 +558,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;
const Shape& shape = ShapeInference::InferConvolveShape(
- lhs_literal->shape(), rhs_literal->shape(),
+ lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();
@@ -585,18 +578,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
- CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+ CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
- absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+ result_literal.shape().dimensions(1),
+ result_literal.shape().dimensions(2),
+ result_literal.shape().dimensions(3));
result->Each([&](absl::Span<const int64> indices, float* value) {
- *value = result_literal->Get<float>(indices);
+ *value = result_literal.Get<float>(indices);
});
return result;
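
The signature changes in this file follow the Abseil guidance that absl::Span, like absl::string_view, is a two-word (pointer, length) view and should be passed by value rather than by const reference. A small self-contained sketch:

  // Sketch: Span by value; containers convert implicitly at the call site.
  #include <vector>
  #include "absl/types/span.h"

  float Sum(absl::Span<const float> xs) {
    float total = 0.0f;
    for (float x : xs) total += x;
    return total;
  }
  // std::vector<float> v = {1.f, 2.f, 3.f};
  // float s = Sum(v);  // v converts to absl::Span<const float>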
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9ce098029d..8654fbb9b5 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -177,47 +177,41 @@ class ReferenceUtil {
// Windowed reductions with Add as the function to apply.
static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
- const absl::Span<const float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const float> operand, float init,
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
// Windowed reductions with a generic reduce function.
static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
// With arbitrary padding.
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
// Batch normalize data.
static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -230,8 +224,8 @@ class ReferenceUtil {
// TODO(b/74533103) Switch tests to evaluator and remove this implementation.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, bool same_padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ bool same_padding);
// Concatenates the lhs and rhs arrays along the concatenate_dimension.
// E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -332,8 +326,8 @@ class ReferenceUtil {
// Slices with index clamping
template <typename T>
- static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
- int64 start, int64 size) {
+ static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
+ int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 3ec0192148..a1b0f4045f 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
- LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
+ LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}
@@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
@@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+ {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}
@@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
- *actual_literal, ErrorSpec(0.0001));
+ {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
} // namespace
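
These test edits are mechanical: the literal factories now return a Literal value, which converts implicitly to the LiteralSlice the Expect* helpers accept, so the dereference disappears. A sketch:

  // Sketch: no '*' needed once factories return Literal by value.
  xla::Literal actual =
      xla::LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
  xla::LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {3.f, 4.f}}, actual,
                                            xla::ErrorSpec(0.0001));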
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 43fd8fe1bd..84fe5b17d1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<float>(expected);
+ Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e784663ff6..fb80c78f68 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -87,6 +87,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -123,6 +124,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -352,6 +354,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -402,6 +405,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -498,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -546,6 +551,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -568,6 +574,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1012,8 +1019,8 @@ cc_library(
":buffer_value_containers",
":heap_simulator",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_proto",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1041,8 +1048,8 @@ tf_cc_test(
":cpu_plugin",
":flatten_call_graph",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1088,8 +1095,8 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1131,6 +1138,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -1139,6 +1147,37 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_module_group",
+ srcs = ["hlo_module_group.cc"],
+ hdrs = ["hlo_module_group.h"],
+ deps = [
+ ":hlo",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_module_group_test",
+ srcs = ["hlo_module_group_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_module_group",
+ ":hlo_parser",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_module_group_metadata",
srcs = ["hlo_module_group_metadata.cc"],
hdrs = ["hlo_module_group_metadata.h"],
@@ -1185,9 +1224,9 @@ tf_cc_test(
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1199,13 +1238,14 @@ tf_cc_test(
)
cc_library(
- name = "hlo_scheduling",
- srcs = ["hlo_scheduling.cc"],
- hdrs = ["hlo_scheduling.h"],
+ name = "hlo_memory_scheduler",
+ srcs = ["hlo_memory_scheduler.cc"],
+ hdrs = ["hlo_memory_scheduler.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1219,15 +1259,15 @@ cc_library(
)
tf_cc_test(
- name = "hlo_scheduling_test",
- srcs = ["hlo_scheduling_test.cc"],
+ name = "hlo_memory_scheduler_test",
+ srcs = ["hlo_memory_scheduler_test.cc"],
deps = [
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1259,6 +1299,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1392,6 +1433,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -1708,6 +1750,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -1777,6 +1820,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -1953,6 +1997,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
+ ":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2236,6 +2281,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2314,6 +2360,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2394,12 +2441,11 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
- ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -2428,6 +2474,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -2494,6 +2541,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2611,6 +2659,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -2888,6 +2937,7 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3d18fe3be2..5458159d14 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
return scalar_add_computation_;
}
+ // Tries to fold a kPad in the input or filter into the convolution
+ // instruction's window.
+ StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+ StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+ // Tries to use a kDot in place of the given convolution.
+ StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
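
With these helpers declared, HandleConvolution reduces to a chain of attempts, each returning true iff it rewrote the convolution. A hedged sketch of that dispatch (the rewritten HandleConvolution itself appears further down in this diff, truncated at the end of this excerpt):

  // Sketch: first successful rewrite wins; later attempts are skipped.
  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
  if (folded_input_pad) {
    return Status::OK();
  }
  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
  if (folded_filter_pad) {
    return Status::OK();
  }
  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
  (void)replaced_with_dot;  // no further simplifications to try
  return Status::OK();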
@@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplification on platforms where it causes a slowdown.
+ // Disable convolution -> dot simplification on platforms where it causes a
+ // slowdown.
bool enable_conv_simplification_;
// Cached computation for adding two scalar F32.
@@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ HloInstruction::CreateConstant(literal.Clone()));
}
}
@@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
+ Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
- HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -1469,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@@ -1572,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1607,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2057,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (pad_literal == reduce_init_literal) {
return true;
}
- auto converted_pad_literal = pad_literal.ConvertToShape(
- reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ auto converted_pad_literal =
+ pad_literal.ConvertToShape(reduce_init_value->shape());
if (!converted_pad_literal.ok()) {
return false;
}
- return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2212,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
HloInstruction* convolution) {
- auto lhs = convolution->mutable_operand(0);
- auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
- ShapeUtil::IsZeroElementArray(rhs->shape())) {
- return ReplaceWithNewInstruction(
- convolution,
- HloInstruction::CreateBroadcast(
- convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type())
- .CloneToUnique())),
- {}));
- }
-
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
const auto& window = convolution->window();
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
- // Try to merge padding/dilation of the input with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
- if (lhs->opcode() != HloOpcode::kPad) {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(lhs->operand(1), 0)) {
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
return false;
}
- const auto& padding = lhs->padding_config();
-
- // Can't pad batch or feature dims.
- for (int64 dim :
- {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
}
+ }
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = window;
- for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
- // Edge padding composes with itself in the straightforward way, but
- // composing interior padding is nontrivial, and we cowardly refuse to
- // think about it. If we see interior padding in either the kPad or conv,
- // bail if there's any sort of padding in the other.
- if (p.interior_padding() != 0 &&
- (w.padding_low() != 0 || w.padding_high() != 0 ||
- w.base_dilation() != 1)) {
- return false;
- }
- if (w.base_dilation() != 1 &&
- (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0)) {
- return false;
- }
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
- w.set_padding_low(w.padding_low() + p.edge_padding_low());
- w.set_padding_high(w.padding_high() + p.edge_padding_high());
- if (p.interior_padding() != 0) {
- CHECK_EQ(w.base_dilation(), 1);
- w.set_base_dilation(1 + p.interior_padding());
- }
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs->mutable_operand(0), rhs});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
- if (folded_input_pad) {
- return Status::OK();
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
}
- // Try to merge dilation of the filter with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
- if (rhs->opcode() != HloOpcode::kPad) {
- return false;
- }
+ const auto& padding = rhs->padding_config();
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(rhs->operand(1), 0)) {
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- const auto& padding = rhs->padding_config();
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
- // Can't pad or dilate feature dims.
- for (int64 dim : {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
}
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = convolution->window();
- for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
-
- // We can only do this transformation if p adds dilation to the filter --
- // edge padding on the filter is not supported in conv.
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
- return false;
- }
-
- // Nothing to do if the kPad for this dim is entirely a nop.
- if (p.interior_padding() == 0) {
- continue;
- }
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
- // We cowardly refuse to think about how dilation composes with itself;
- // bail if both the kPad and conv have dilation on this dimension.
- if (w.window_dilation() > 1) {
- return false;
- }
- CHECK_EQ(w.window_dilation(), 1);
- w.set_window_dilation(1 + p.interior_padding());
- w.set_size(rhs->operand(0)->shape().dimensions(
- dnums.kernel_spatial_dimensions(dim)));
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
}
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs, rhs->mutable_operand(0)});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
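[Editor's note] The window-dilation update above encodes a small arithmetic fact: interior padding of p zeros between kernel elements makes the kernel cover exactly the same input positions as window dilation 1 + p, which is also why the window size is reset to the pre-pad operand's spatial bound. A standalone sketch (not part of the patch) checking that equivalence:

```cpp
// Sketch, not part of the patch: a kernel of size k with p interior zeros
// spans k + (k - 1) * p input positions; a window of size k with dilation d
// spans k + (k - 1) * (d - 1). The two agree exactly when d = 1 + p, which
// is what FoldConvFilterPad sets above.
#include <cassert>
#include <cstdint>

int64_t PaddedKernelSpan(int64_t k, int64_t p) { return k + (k - 1) * p; }
int64_t DilatedWindowSpan(int64_t k, int64_t d) { return k + (k - 1) * (d - 1); }

int main() {
  for (int64_t k = 1; k <= 8; ++k) {
    for (int64_t p = 0; p <= 4; ++p) {
      assert(PaddedKernelSpan(k, p) == DilatedWindowSpan(k, 1 + p));
    }
  }
  return 0;
}
```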
- if (folded_filter_pad) {
- return Status::OK();
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
if (!enable_conv_simplification_) {
- return Status::OK();
+ return false;
}
- // HandleConvolution tries to replace a convolution with a DOT instruction.
- //
- // Only add when bitcasts can be used:
- // - if bitcasts are not supported, then reshapes could be used but will
- // end up with another copy.
- // - if bitcasts are supported, the simplifier will be called again with
- // bitcasts_ == true.
- // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+ // layout-insensitive mode, for fear of adding nontrivial reshapes.
if (!is_layout_sensitive_) {
- return Status::OK();
+ return false;
}
const Shape& input_shape = lhs->shape();
@@ -2388,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Require the spatial dimensions in the kernel to have a bound of one.
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
- return Status::OK();
+ return false;
}
}
@@ -2399,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
- return Status::OK();
+ return false;
}
// Also, the shapes must align for a rowmajor matmul:
@@ -2425,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dnums.kernel_input_feature_dimension()) <
PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
dnums.kernel_output_feature_dimension()))) {
- return Status::OK();
+ return false;
}
auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2467,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
- return Status::OK();
+ return false;
}
auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2479,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
convolution->precision_config()));
- return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution) {
+ // Zero-sized input or filter.
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+ return ReplaceWithNewInstruction(
+ convolution,
+ HloInstruction::CreateBroadcast(
+ convolution->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type()))),
+ {}));
+ }
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
+ // Try to replace the convolution with a kDot instruction.
+ TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+ if (replaced_with_dot) {
+ return Status::OK();
+ }
+
+ return Status::OK();
}
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
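[Editor's note] SimplifyConvToDot only fires for 1x1 spatial kernels with no stride, padding, or base dilation, because then the convolution touches each input position exactly once and reduces over the input-feature dimension alone. A minimal shape-level sketch of the rewrite (names illustrative, not from the patch; layouts assumed row-major as the pass requires):

```cpp
// Sketch: with every kernel spatial bound equal to 1, the convolution is a
// plain matmul over the feature dimensions, reachable by bitcasts alone:
//   input  [batch, spatial..., C_in]  --bitcast--> [batch * prod(spatial), C_in]
//   filter [1, ..., 1, C_in, C_out]   --bitcast--> [C_in, C_out]
//   dot result [batch * prod(spatial), C_out] --bitcast--> convolution shape
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> ConvAsDotShape(int64_t batch,
                                    const std::vector<int64_t>& spatial,
                                    int64_t c_in, int64_t c_out) {
  // Rows of the matmul: batch times every spatial bound.
  int64_t rows = std::accumulate(spatial.begin(), spatial.end(), batch,
                                 [](int64_t a, int64_t b) { return a * b; });
  return {rows, c_in, c_out};  // lhs [rows, c_in] x rhs [c_in, c_out]
}
```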
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index a0db4563fb..3fc1ba2427 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector)};
+ Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
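[Editor's note] This hunk is the test-side face of the Literal migration visible throughout this change: factory functions now return Literal by value instead of std::unique_ptr<Literal>, and MakeTuple takes pointers to caller-owned elements that it copies. A minimal sketch of the new calling convention, assuming only what the hunks here show:

```cpp
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch of the post-change Literal API as used by these tests.
void BuildTupleConstant() {
  // Factories return Literal by value; no unique_ptr, no .get().
  xla::Literal scalar = xla::LiteralUtil::CreateR0<float>(7.3f);
  xla::Literal vector = xla::LiteralUtil::CreateR1<float>({1.1f, 2.0f, 3.3f});
  // MakeTuple copies from the pointed-to elements; the caller keeps
  // ownership of `scalar` and `vector`.
  xla::Literal tuple = xla::LiteralUtil::MakeTuple({&scalar, &vector});
  (void)tuple;
}
```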
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec281ae68f..30d33e0d35 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
+ elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
- for (auto* instruction : module->entry_computation()->instructions()) {
+ for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
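[Editor's note] The batchnorm test migration above is representative of the HloTestBase to HloVerifiedTestBase moves in the rest of this change: the verified base owns the module, hands tests a raw HloModule* (or module() for parsed input), and re-verifies the HLO after the test runs. A condensed sketch of the pattern, where MyPass and kModuleText are placeholders, not names from the patch:

```cpp
// Sketch of the migrated fixture pattern (assumes the usual test includes).
class MyPassTest : public HloVerifiedTestBase {};

TEST_F(MyPassTest, RunsOnVerifiedModule) {
  // The fixture parses, owns, and later re-verifies the module; the test
  // receives it by reference instead of holding a unique_ptr.
  ParseAndVerifyModule(kModuleText);  // kModuleText: HLO text, assumed given
  MyPass pass;                        // MyPass: hypothetical HloModulePass
  ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
}
```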
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21c3b..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
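[Editor's note] The folding fixture gains explicit base-class flags because these tests intentionally build graphs mixing F32 and BF16, which a strict verifier rejects. A minimal sketch of the knob, assuming HloVerifiedTestBase forwards its two arguments to the underlying HloVerifier:

```cpp
// Sketch: with allow_mixed_precision=false, the verifier would reject the
// mixed F32/BF16 modules these tests construct before the pass under test
// ever runs, so the fixture opts into the relaxed mode up front.
class MixedPrecisionPassTest : public HloVerifiedTestBase {
 protected:
  MixedPrecisionPassTest()
      : HloVerifiedTestBase(/*layout_sensitive=*/false,
                            /*allow_mixed_precision=*/true) {}
};
```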
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 933cf873e0..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 545a6ecfb1..58f78f8e24 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
- TF_ASSIGN_OR_RETURN(
- auto converted_literal,
- hlo->literal().ConvertToShape(hlo->shape(),
- /*round_f32_to_bf16=*/true));
+ TF_ASSIGN_OR_RETURN(auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 388fd5df99..e032b5c624 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 0f0af57626..65fa951afe 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 5a231c173d..795beb9ff5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -30,11 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
+ Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()})));
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 414bfe7999..17e5090505 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()});
- auto inner_tuple1 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
+ auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+ Literal element1 = LiteralUtil::CreateR0<int64>(3);
+ auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- inner_tuple0->shape(), tuple_constant, 0));
+ inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74843..34f3f914d5 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(1, call_graph->nodes().size());
EXPECT_TRUE(call_graph->IsFlattened());
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
// The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(3, call_graph->nodes().size());
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Test visitation of only reachable nodes.
{
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
auto module = CreateNewModule();
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f173..e6b5665435 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, pred, "param"));
+ call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}
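[Editor's note] The new CreateParameter in CallsWithinWhileBodiesAreInlined is not incidental: under the verified test base the module must type-check, and a computation used as a while condition has to accept the loop state as parameter 0 even when its root ignores it. A sketch of that contract with the scalar-pred loop state this test uses:

```cpp
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: for loop state shape S, a while op requires
//   condition : (S) -> pred[]    and    body : (S) -> S.
// Here S = pred[], so the condition computation must declare a pred[]
// parameter 0 for the module to verify.
void BuildWhileCondition() {
  xla::HloComputation::Builder cond_builder("cond");
  const xla::Shape pred = xla::ShapeUtil::MakeShape(xla::PRED, {});
  cond_builder.AddInstruction(
      xla::HloInstruction::CreateParameter(0, pred, "loop_state"));
  // The root (added next in the real test) is a kCall producing pred[].
}
```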
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28478..96bd2616f5 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(instance.computation, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 0826380f65..0ac4a65ec6 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 2368ac8c6a..8cc522a59e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
@@ -801,6 +801,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -822,6 +823,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -946,6 +948,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -971,6 +974,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 05792795a1..2083f440fd 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {
@@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e7b6075994..18fc144efe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,12 +77,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa446e..c9fb34be1c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*module), 3);
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6de6..be1208fb2d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 942e2ddd39..55d5925642 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -37,21 +37,20 @@ int main(int argc, char** argv) {
xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
// Transfer parameters.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@ int main(int argc, char** argv) {
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
- xla::StatusOr<std::unique_ptr<xla::Literal>> result =
- client->ExecuteAndTransfer(
- computation,
- /*arguments=*/{param0_data.get(), param1_data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/&profile);
- std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+ xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*execution_options=*/nullptr,
+ /*execution_profile=*/&profile);
+ xla::Literal actual = result.ConsumeValueOrDie();
LOG(INFO) << absl::StrFormat("computation took %dns",
profile.compute_time_ns());
- LOG(INFO) << actual->ToString();
+ LOG(INFO) << actual.ToString();
return 0;
}
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index 7d8e51f909..1a3d82de95 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
protected:
typedef std::vector<int64> Vec;
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
}
}
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index f11aff0573..c55206eee7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,6 +48,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 22721051e5..1deb412064 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
protected:
CpuFusionTest() {}
@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
- Shape vshape = input_literal1->shape();
+ Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
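[Editor's note] The shift from ExecuteAndTransfer(std::move(module), {}) to module->Clone() in these fusion tests also follows from the fixture change: the verified base owns the module so it can re-verify it on teardown, which means the test cannot move it away and instead executes a deep copy. A one-line sketch, assuming that ownership model:

```cpp
// `module` is a fixture-owned HloModule*; execution consumes a clone so the
// fixture can still verify the original when the test finishes.
auto result = ExecuteAndTransfer(module->Clone(), /*arguments=*/{});
```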
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
- error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
<< fusion_instruction2->fused_instructions_computation()->ToString();
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
- *result, error_spec_);
+ result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// Run fusion.
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
auto fusion1 = result->operand(0);
auto fusion2 = result->operand(1);
@@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The only fusion instruction should be operand 0 of the tuple (formerly
// negate1).
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c661..5cc6d01c0f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+ LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+ LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 batches of infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace
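Every infeed hunk above follows the same two-part migration: the LiteralUtil factory functions now return Literal by value rather than std::unique_ptr<Literal>, and tuple construction moves from MakeTuple over .get() pointers to MakeTupleFromSlices over values. A minimal sketch of the new calling convention, reusing the fixture's client_ (the concrete values are illustrative):

  // No unique_ptr, no dereference: the Literal temporaries convert to
  // LiteralSlice, and the resulting tuple Literal is passed directly.
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
                                        LiteralUtil::CreateR0<bool>(true)})));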
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index bb105194f1..7af51db55a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 1b3be199f6..852f34e06d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -56,9 +56,9 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
- RunTest(hlo_text, {lhs.get(), rhs.get()});
+ Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla
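One consequence of the value-returning factories worth noting: where an API still takes Literal pointers, as RunTest does above, the caller now owns the literals on the stack and must keep them alive across the call. A condensed sketch of that pattern:

  Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});  // stack-owned value
  Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
  RunTest(hlo_text, {&lhs, &rhs});  // pointers into the enclosing scope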
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());
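The dropped .get() calls above follow from the switch to HloVerifiedTestBase: the fixture owns the module, hands out a raw HloModule*, and verifies it at teardown. A hedged sketch of the resulting test shape, using the helpers defined in this file (the assertions are illustrative):

  TEST_F(FlattenCallGraphTest, Sketch) {
    HloModule* module = CreateNewModule();  // owned and verified by the fixture
    module->AddEntryComputation(MakeScalarComputation());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, RunFlattenCallGraph(module));
    EXPECT_FALSE(changed);  // a single call-free computation has nothing to flatten
  }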
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 4ed91ef187..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
- std::unique_ptr<Literal> relayed_out_literal;
+ Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->untyped_data();
+ source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,
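One subtlety in the hunk above: relayed_out_literal is now a Literal value declared before the branch so that its backing buffer outlives source, which points into it. A condensed sketch of that flow (names as in the patch; the fast path is assumed from the surrounding context):

  Literal relayed_out_literal;  // empty until a relayout is actually needed
  const void* source;
  if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) {
    source = subliteral.untyped_data();  // layouts already match
  } else {
    relayed_out_literal =
        subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{});
    source = relayed_out_literal.untyped_data();  // points into the value above
  }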
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6791e15ee0..64b9683628 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -173,6 +174,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
@@ -370,6 +372,8 @@ cc_library(
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
deps = [
+ ":backend_configs",
+ ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -395,6 +399,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -813,9 +818,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
)
@@ -832,6 +837,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -901,6 +907,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 05448d863d..3a23ac1d63 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,62 +31,32 @@ namespace gpu {
using se::dnn::AlgorithmDesc;
-ConvolutionThunk::ConvolutionThunk(
- CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
- int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
- : Thunk(Kind::kConvolution, hlo),
- convolution_kind_(convolution_kind),
- input_buffer_(input_buffer),
- filter_buffer_(filter_buffer),
- output_buffer_(output_buffer),
- tuple_result_buffer_(tuple_result_buffer),
- scratch_buffer_(scratch_buffer),
- input_shape_(input_shape),
- filter_shape_(filter_shape),
- output_shape_(output_shape),
- window_(window),
- dim_nums_(dim_nums),
- feature_group_count_(feature_group_count),
- algorithm_(algorithm),
- tensor_ops_enabled_(tensor_ops_enabled) {}
-
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- se::DeviceMemoryBase input_data =
- buffer_allocations.GetDeviceAddress(input_buffer_);
- se::DeviceMemoryBase filter_data =
- buffer_allocations.GetDeviceAddress(filter_buffer_);
- se::DeviceMemoryBase output_data =
- buffer_allocations.GetDeviceAddress(output_buffer_);
+ CudnnConvParams params;
+
+ params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+ params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(
- convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_,
- feature_group_count_, algorithm_config, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
void* result_ptr = [&] {
- switch (convolution_kind_) {
+ switch (params.kind) {
case CudnnConvKind::kForward:
- return output_data.opaque();
+ return params.output_buf.opaque();
case CudnnConvKind::kBackwardInput:
- return input_data.opaque();
+ return params.input_buf.opaque();
case CudnnConvKind::kBackwardFilter:
- return filter_data.opaque();
+ return params.filter_buf.opaque();
}
}();
void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c40c5..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@ limitations under the License.
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums,
- int64 feature_group_count, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 feature_group_count_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu
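With the constructor now inlined in the header, creating the thunk collapses to a single expression; the ir_emitter_unnested hunk later in this patch does exactly this. Sketch (names as in the patch):

  // cudnn_call is the HloCustomCallInstruction for the conv; the slices are
  // the buffer-assignment results for its operands and its tuple outputs.
  thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
      Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
      output_slice, scratch_slice, tuple_result_slice));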
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5c2555148a..f528e62b17 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/mutex.h"
@@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr) {
+ const HloCustomCallInstruction* instr) {
+ CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
+
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -216,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
allocator = &*se_allocator;
}
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- if (cross_check_enabled) {
- // Broadcast a constant to the buffer, instead of zeroing the buffer. A
- // non-zero constant is useful for the cross checking, because zero-inputs
- // may not always reveal the bugs.
- const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ const auto initialize_buffer = [&stream, cross_check_enabled](
+ DeviceMemoryBase buffer) {
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for cross checking, because zero inputs
+ // may not always reveal bugs.
CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
size_t left_over_bytes = buffer.size() % 4;
CHECK_EQ(0, left_over_bytes % 2);
@@ -252,33 +244,46 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
DeviceMemoryBase left_over(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
- };
- initialize_f16(input_buf);
- initialize_f16(filter_buf);
- initialize_f16(output_buf);
- } else {
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the
- // inputs might affect performance if e.g. the inputs contain denormals, and
- // this is easy enough.
- stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size());
- }
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ stream.ThenMemZero(&buffer, buffer.size());
+ }
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ TF_ASSIGN_OR_RETURN(params.input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(params.filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(params.output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
+
+ initialize_buffer(params.input_buf);
+ initialize_buffer(params.filter_buf);
+ initialize_buffer(params.output_buf);
DeviceMemoryBase* result_buf = [&] {
- switch (kind) {
+ switch (params.kind) {
case CudnnConvKind::kBackwardFilter:
- return &filter_buf;
+ return &params.filter_buf;
case CudnnConvKind::kBackwardInput:
- return &input_buf;
+ return &params.input_buf;
case CudnnConvKind::kForward:
- return &output_buf;
+ return &params.output_buf;
}
}();
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, dnums, stream_exec_);
+ input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
@@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
for (const AlgorithmDesc& alg :
- GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+ GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok =
- RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf,
- filter_buf, output_buf, &scratch_allocator, window, dnums,
- feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ params.algorithm = AlgorithmConfig(alg);
+ bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
+ &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
const bool crash_on_checking_failure =
@@ -374,34 +377,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- const auto& call_target = instr->custom_call_target();
- const auto& lhs_shape = instr->operand(0)->shape();
- const auto& rhs_shape = instr->operand(1)->shape();
- const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
- if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc =
- PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr->feature_group_count(),
- instr);
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instr->ToString();
- }
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
if (!alg_scratch_and_tc.ok()) {
LOG(ERROR) << alg_scratch_and_tc.status();
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 0cb01161b0..f79b113f8f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr);
+ const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721ecd2..228379a248 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include <cstdlib>
#include <numeric>
#include <vector>
@@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward filter.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
- HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
- std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward input.
+ // TODO(b/31709653): Theoretically cuDNN also supports grouped convolutions
+ // for the backward input convolution, but at least as of version 7.1.4 it is
+ // slower. This needs to be re-evaluated for future cuDNN versions.
+ // Note that the necessary code is already in place below; to enable it,
+ // simply remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else if (reverse_filter->IsConstant()) {
- // If the filter is a constant, we're willing to pattern-match to a
- // backwards-input conv, on the theory that
- //
- // a) reversing a constant is free, and
- // b) even if the user specified this filter as reverse(constant), we would
- // long ago have constant-folded away the reverse.
- //
- // If the constant has any other uses, reversing it isn't entirely free,
- // since we'd now have two constants to keep in memory. But hopefully it's
- // free enough.
- //
- // TODO(jlebar): Should we do this even if the filter is not a constant?
- // Reversing a non-constant filter is probably cheaper than padding the
- // input!
-
- // Nothing to do, just fall through.
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
// Match padding and dilation of the forward convolution.
@@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
- // This simplifies things for our caller, and algebraic-simplifier will later
- // remove any unnecessary reverses.
- if (reverse_filter->opcode() != HloOpcode::kReverse) {
+ // OK, it's a match! Switch the input feature dimension with the output
+ // feature dimension. This is the way cuDNN expects it to be.
+ dnums.set_kernel_input_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
// Create a double-reverse, which is a nop.
HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
- dnums.set_kernel_input_feature_dimension(
- conv->convolution_dimension_numbers().kernel_output_feature_dimension());
- dnums.set_kernel_output_feature_dimension(
- conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- return std::make_tuple(true, new_window, dnums);
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(true, new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature dimension
+ // with the output feature dimension, we also need to reshape the kernel so
+ // that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ Shape new_shape = rhs->shape();
+ int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+ // In the backward convolution case, the spatial dimensions become the
+ // feature dimensions, and we are guaranteed that the spatial dimensions are
+ // adjacent.
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+ int64 input_features = new_shape.dimensions(input_feature_dimension);
+ int64 output_features = new_shape.dimensions(output_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension,
+ input_features / conv->feature_group_count());
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * conv->feature_group_count());
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(true, new_window, dnums, rhs);
}
// Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
+ HloInstruction* rhs;
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
@@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
window, dnums, conv->feature_group_count());
}
- std::tie(match, window, dnums) = MatchBackwardInput(conv);
+ std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- // Backward input conv subsumes the conv plus the reverse in operand 1.
- HloInstruction* reverse = conv->mutable_operand(1);
- CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
- HloInstruction* rhs = reverse->mutable_operand(0);
-
return CreateCudnnConvBackwardInput(conv->shape(),
conv->mutable_operand(0), rhs, window,
dnums, conv->feature_group_count());
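To make the grouped-convolution reshape above concrete, a worked example (numbers illustrative): take an rhs kernel of shape f32[3,3,16,8] where dimension 2 was the kernel input-feature dimension (16 features), dimension 3 the kernel output-feature dimension (8 features), and feature_group_count is 4. After the swap in dnums, dimension 3 becomes the input-feature dimension and dimension 2 the output-feature dimension, so the reshape sets dimension 3 to 8 / 4 = 2 and dimension 2 to 16 * 4 = 64, yielding f32[3,3,64,2]. The CHECK_EQ holds because the two feature dimensions (2 and 3) are adjacent.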
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index bda8ebe579..d237f8930b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
- LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9d1f..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
- output_buf, &scratch_allocator, window, dnums, feature_group_count,
- algorithm, stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index a1b4fc71d0..381aa37a1b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@ enum class CudnnConvKind {
kBackwardFilter, // input + output => filter
};
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
//
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment convolution with half data type is implemented with cudnn
+// PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
@@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
} // namespace gpu
} // namespace xla
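A sketch of the new calling convention for the params-based overloads above. Everything here besides the struct fields and the RunCudnnConvolution signature is a caller-declared placeholder; the pointer fields must outlive the call:

  CudnnConvParams params;
  params.kind = CudnnConvKind::kForward;
  params.input_shape = &input_shape;    // caller-owned Shape objects
  params.filter_shape = &filter_shape;
  params.output_shape = &output_shape;
  params.input_buf = input_buf;         // se::DeviceMemoryBase handles
  params.filter_buf = filter_buf;
  params.output_buf = output_buf;
  params.window = &window;
  params.dnums = &dnums;
  params.feature_group_count = 1;
  params.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(/*algo_id=*/0, /*tensor_ops=*/false));
  TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch_buf, stream));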
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index ea9376e101..02a0d028c1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 59ade96f7d..b857fa775a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -24,14 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089df4c..27a4d0b601 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 20d523abe0..22f43bc08b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params) {
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
+ const auto& target = custom_call->custom_call_target();
+ const auto& lhs_shape = custom_call->operand(0)->shape();
+ const auto& rhs_shape = custom_call->operand(1)->shape();
+ const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+ params->window = &custom_call->window();
+ params->dnums = &custom_call->convolution_dimension_numbers();
+ params->feature_group_count = custom_call->feature_group_count();
+ params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params->kind = CudnnConvKind::kForward;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &conv_result_shape;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params->kind = CudnnConvKind::kBackwardInput;
+ params->input_shape = &conv_result_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &lhs_shape;
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params->kind = CudnnConvKind::kBackwardFilter;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &conv_result_shape;
+ params->output_shape = &rhs_shape;
+ } else {
+ LOG(FATAL) << "Unexpected custom call target: "
+ << custom_call->custom_call_target();
+ }
+ return Status::OK();
+}
+
} // namespace gpu
} // namespace xla
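Both call sites in this patch (ConvolutionThunk::ExecuteOnStream and CudnnConvolutionAlgorithmPicker::PickBestAlgorithm) use the new helper the same way; a condensed sketch:

  CudnnConvParams params;
  TF_RETURN_IF_ERROR(PopulateCudnnConvParams(
      Cast<HloCustomCallInstruction>(instr), &params));
  // params.kind, shapes, window, dnums, and algorithm are now filled in;
  // the device buffers are left for the caller to set.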
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 59c65fc268..09c455cc1e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
+// Populates params using custom_call, which must be a custom call to a cudnn
+// convolution. Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f91cc00d71..b669881026 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
- std::unique_ptr<ConvolutionThunk> thunk;
+ BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
if (target == kCudnnConvForwardCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kForward,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/conv_result_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = rhs_slice;
+ output_slice = conv_result_slice;
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardInput,
- /*input_buffer=*/conv_result_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/lhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/lhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = conv_result_slice;
+ filter_slice = rhs_slice;
+ output_slice = lhs_slice;
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardFilter,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/conv_result_slice,
- /*output_buffer=*/rhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape,
- /*output_shape=*/rhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = conv_result_slice;
+ output_slice = rhs_slice;
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
}
- thunk_sequence_->emplace_back(std::move(thunk));
+ thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+ Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+ output_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index f6325b3368..dfdcf1875d 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
- // CudnnConvolutionRewriter may add instructions of the form
- // reverse(constant), which it expects will be simplified by constant
- // folding.
- pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// pairs that TupleSimplifier fixes.
pipeline.AddPass<TupleSimplifier>();
}
+ // CudnnConvolutionRewriter, PadInsertion, and PadForTensorCores may add
+ // instructions that can be simplified by constant folding.
+ pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
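
These two hunks move the single constant-folding pass to run after all of the convolution-related rewrites rather than immediately after CudnnConvolutionRewriter. Assuming the surrounding pipeline code, the resulting pass order is:

    pipeline.AddPass<CudnnConvolutionRewriter>();
    pipeline.AddPass<PadInsertion>();
    if (IsVoltaOrLater(*stream_exec)) {
      pipeline.AddPass<PadForTensorCores>();
      pipeline.AddPass<TupleSimplifier>();
    }
    // Folds the reverse(constant) / pad(constant) instructions that the
    // passes above may have introduced.
    pipeline.AddPass<HloConstantFolding>();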
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index fa84d77223..b0061fa655 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-
// We want the input/output feature counts of an f16 conv to be multiples of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
- auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
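
As a worked example of the multiple-of-8 requirement: an f16 convolution with 60 input features gets its feature dimension zero-padded up to 64, the next multiple of 8, so cuDNN can dispatch the conv to tensor cores. A sketch of the rounding, assuming XLA's RoundUpToNearest helper:

    // Illustrative only: round a feature count up to a multiple of 8.
    int64 PaddedFeatureCount(int64 features) {
      return RoundUpToNearest<int64>(features, kDesiredNumFeaturesFactor);
    }
    // PaddedFeatureCount(60) == 64, PaddedFeatureCount(64) == 64.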
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9d85d746d8..2a6415d0b6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
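
All three hunks in this file are the same mechanical migration: Literal is now a move-only value type, so HloInstruction::CreateConstant takes a Literal by value rather than a std::unique_ptr<Literal>. The pattern, extracted for clarity:

    // Before, the literal had to be wrapped in a heap allocation:
    //   HloInstruction::CreateConstant(
    //       absl::make_unique<Literal>(LiteralUtil::Zero(element_type)));
    // After, the temporary Literal is moved straight into the constant:
    HloInstruction* zero = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));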
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 8f0dedfa40..c4f43cc9a6 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -21,14 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
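
HloVerifiedTestBase differs from HloTestBase in two ways that recur throughout this change: it runs the HLO verifier over each module automatically, and the fixture owns the module, so tests work with a reference (via module()) instead of holding a std::unique_ptr. The call sites change roughly like this:

    // HloTestBase: the test owns the module.
    //   TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
    // HloVerifiedTestBase: the fixture owns (and verifies) the module and
    // hands out a raw pointer or reference.
    TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module));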
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index 4550f36fdf..780539c164 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30317..f8120a5fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+ TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
- LiteralUtil::CreateR0<int32>(5).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR4FromArray4D<float>(array),
+ LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace
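
LiteralUtil::MakeTuple takes raw Literal pointers, which is why the old tests called .get() on each element; the new MakeTupleFromSlices accepts the element literals directly (as literal slices) and copies them into the tuple, so temporaries can be passed inline:

    // No .get() plumbing: element literals are passed straight through.
    Literal tuple = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
         LiteralUtil::CreateR0<bool>(false)});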
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 40183de96e..9a61f8ac5a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -26,9 +26,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
class WhileTransformerTest : public HloTestBase {
protected:
WhileTransformerTest()
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 00a25db467..957c4a6891 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewModule();
@@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- HloSchedule schedule(module.get());
+ HloSchedule schedule(module);
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_data, cond_lt});
schedule.set_sequence(body_computation, {body_param});
@@ -233,7 +233,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 93ec2c9438..b19ec12638 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -309,6 +309,13 @@ message HeapSimulatorTrace {
bool whole_module_simulation = 2;
}
+// An abstraction representing a set of HLO modules built to run concurrently
+// across different devices.
+message HloModuleGroupProto {
+ string name = 1;
+ repeated HloModuleProto hlo_modules = 2;
+}
+
// Serialization of BufferAssignment.
message BufferAssignmentProto {
// Alias represents a source LogicalBuffer, and the buffer location that
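
Protoc generates the usual C++ accessors for the new message; a sketch of assembling a two-module group (variable names are illustrative):

    HloModuleGroupProto group;
    group.set_name("pipelined_group");         // illustrative name
    *group.add_hlo_modules() = module0_proto;  // each an HloModuleProto
    *group.add_hlo_modules() = module1_proto;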
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 233d2199d1..8c6903d766 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ auto computation = absl::WrapUnique(
+ new HloComputation(proto.name(), parameter_count, &instructions, root,
+ /*fusion_instruction=*/nullptr));
+ computation->unique_id_ = proto.id();
+ return std::move(computation);
}
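
The explicit std::move on the named local is required, not stylistic: the return type is a StatusOr wrapping the unique_ptr, and (pre-C++20) the implicit move-on-return does not kick in for such converting constructors. The general shape, with a placeholder type:

    // T is a placeholder. A named unique_ptr local must be moved when the
    // return type is StatusOr<std::unique_ptr<T>>.
    StatusOr<std::unique_ptr<T>> MakeConfigured(int id) {
      auto obj = absl::WrapUnique(new T());
      obj->set_id(id);        // the reason a named local exists at all
      return std::move(obj);  // the conversion needs an rvalue
    }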
void HloComputation::FuseInstructionsInto(
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 8a45939c61..f837816cea 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
- std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
- if (result == nullptr) {
+ if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
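
With Literal no longer arriving as a nullable unique_ptr, the caller-side pattern for the reshaped TryEvaluate is a default-constructed Literal plus a bool check, after which the folded value can be moved into a constant, roughly as this pass goes on to do:

    Literal folded;
    if (evaluator->TryEvaluate(instruction, &folded)) {
      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(std::move(folded))));
    }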
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1efc12..3e0def5d26 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
@@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
+ ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);
- EXPECT_EQ(6, module->entry_computation()
+ EXPECT_EQ(6, module()
+ .entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
- HloInstruction* add = module->computations().begin()->root_instruction();
+ ParseAndVerifyModule(kConstantFoldReduce);
+ HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);
- EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a3fcc0fefa..b76c50bb5b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(operand->shape().element_type()))));
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
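
Per the tests later in this change, BroadcastZeros composes the zero constant above with a broadcast; for example, the following yields a 2x2 array of S32 zeros:

    // Creates a scalar zero of the given type, broadcast to shape {2, 2}.
    TF_ASSIGN_OR_RETURN(HloInstruction * zeros,
                        BroadcastZeros(computation, S32, {2, 2}));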
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index eb6affadc8..e07a196d11 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<int32>(
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e09d5868f2..9b18b0284f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index d0d955fea8..06b6d5b559 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -54,9 +54,8 @@ namespace xla {
namespace {
template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
@@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
@@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
@@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
- .CloneToUnique();
+ .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloModule& module, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
@@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+ return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
@@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = input_literal->CloneToUnique();
+ evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal : arg_literals) {
+ arg_literal_ptrs.push_back(&literal);
+ }
+ return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction) {
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
- HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+ CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
- return nullptr;
+ return false;
}
- return result_or.ConsumeValueOrDie();
+ *result = result_or.ConsumeValueOrDie();
+ return true;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
- HloInstruction::CreateConstant(it->second->CloneToUnique()));
+ HloInstruction::CreateConstant(it->second->Clone()));
}
}
@@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
- HloInstruction::CreateConstant(operand.CloneToUnique());
+ HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = input_literal->CloneToUnique();
+ evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
- std::unique_ptr<Literal>* reshaped_start_indices) {
+ Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
@@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
- return std::cref(**reshaped_start_indices);
+ return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
- std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_start_indices;
+ Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
- result->CopyElementFrom(operand, input_index, output_index));
+ result.CopyElementFrom(operand, input_index, output_index));
return true;
};
@@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand.shape().dimensions(i));
+ auto operand_dim_size = operand.shape().dimensions(i);
+ auto broadcast_dim_size =
+ broadcast->shape().dimensions(broadcast->dimensions(i));
+ TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
+ "Operand dimension %d is broadcast to output dimension %d, but the "
+ "sizes of these two dims do not match (%d vs %d): %s",
+ i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
+ broadcast->ToString());
}
TF_ASSIGN_OR_RETURN(
@@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = absl::make_unique<Literal>(
- ShapeUtil::GetTupleElementShape(operand->shape(), index));
- return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
- /*dest_shape_index=*/{},
- /*src_shape_index=*/{index});
+ evaluated_[get_tuple_element] =
+ Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
- auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
- evaluated_[copy] = std::move(result);
+ evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result;
+ Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
- evaluated_[select] = on_true.CloneToUnique();
+ evaluated_[select] = on_true.Clone();
} else {
- evaluated_[select] = on_false.CloneToUnique();
+ evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
- evaluated_[tuple_select] = on_true.CloneToUnique();
+ evaluated_[tuple_select] = on_true.Clone();
} else {
- evaluated_[tuple_select] = on_false.CloneToUnique();
+ evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop-carried value with the input to the While instruction.
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+ auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
@@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
- TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
- *cond_comp, {lcv.get()}));
- keep_going = cond_val->GetFirstElement<bool>();
+ TF_ASSIGN_OR_RETURN(auto cond_val,
+ cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+ keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
- *body_comp, {lcv.get()}));
- VLOG(3) << "Loop iteration result: " << body_val->ToString();
+ *body_comp, {&lcv}));
+ VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1173,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
- auto result_values_literal =
- absl::make_unique<Literal>(values_literal.shape());
- result_values_literal->PopulateR1(
+ Literal result_keys_literal(keys_literal.shape());
+ result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal result_values_literal(values_literal.shape());
+ result_values_literal.PopulateR1(
absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
- std::unique_ptr<Literal> result_tuple;
+ Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple = LiteralUtil::MakeTuple(
- {result_pair.first.get(), result_pair.second.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
- auto values_result_literal =
- absl::make_unique<Literal>(values_literal.shape());
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
+ .Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ .Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first->Reshape({1, r1_length}));
+ r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
- *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
- *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ r1_result_pair.second.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
- result_tuple = LiteralUtil::MakeTuple(
- {keys_result_literal.get(), values_result_literal.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
- VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1242,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
- const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) {
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+ // For convenience, the evaluated literal may have been produced with a
+ // different layout. Relayout it as indicated by the HLO instruction.
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+ hlo->shape())) {
+ evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+ }
return Status::OK();
}
// Explicit instantiation of templatized Evaluate* methods.
//
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloModule& module, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
- const Literal*>(const HloComputation& computation,
- absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- HloInstruction* instruction,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
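
The dropped std::unique_ptr<Literal> instantiations are superseded by the new Evaluate<Literal> specializations added earlier in this file, which simply collect pointers to the caller's literals. Both calling styles remain available:

    HloEvaluator evaluator;
    // Owned literals, passed as a Span<const Literal>:
    TF_ASSIGN_OR_RETURN(
        Literal result,
        evaluator.Evaluate<Literal>(*module,
                                    {LiteralUtil::CreateR1<int32>({3, 4})}));
    // Borrowed literals, passed as a Span<const Literal*>:
    Literal arg = LiteralUtil::CreateR0<int32>(9);
    TF_ASSIGN_OR_RETURN(Literal result2, evaluator.Evaluate<const Literal*>(
                                             *module, {&arg}));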
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 72252bafc7..21e676d671 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloModule& module,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloComputation& computation,
- absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
@@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
- StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction);
- // Same as Evaluate, except returning nullptr on error.
- std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+ // Same as Evaluate, except it returns false on error and accepts an
+ // output pointer.
+ bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
- StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+ StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
- HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+ StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
- HloOpcode opcode, const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+ const Literal& operand);
- StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfig& precision_config, const Literal& lhs,
- const Literal& rhs);
+ StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
- return *(it->second);
+ return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-ordering.
// Must be cleared for each evaluation.
- tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
- evaluated_;
+ // Storing Literal in place requires the container to have pointer
+ // stability, so we cannot use FlatMap anymore.
+ std::unordered_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
- static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
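The explicit `std::move` on the converting return is deliberate: `Literal` is move-only, so the return must move, and the codebase spells the move out rather than relying on implicit-move rules for converting returns into `StatusOr<Literal>`, which varied across the toolchains of the era. A self-contained sketch of the same pattern, assuming only an F32 shape:

    StatusOr<Literal> MakeZeros(const Shape& shape) {
      Literal out(shape);
      TF_RETURN_IF_ERROR(out.Populate<float>(
          [](absl::Span<const int64> /*index*/) { return 0.0f; }));
      return std::move(out);  // converting return into StatusOr<Literal>
    }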
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 102ebb24ab..01e88566a5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
evaluator_ = absl::make_unique<HloEvaluator>();
}
- std::unique_ptr<Literal> Evaluate(
- absl::Span<const Literal* const> arg_literals = {}) {
+ Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<HloEvaluator> evaluator_;
- void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> input, float aabs = 0) {
+ void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+ float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
- b.AddInstruction(
- HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+ b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto element_type = expected->shape().element_type();
+ auto element_type = expected.shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
}
- void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> lhs,
- std::unique_ptr<Literal> rhs) {
+ void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+ Literal rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(
- HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+ HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
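Since the helpers now take `Literal` by value, call sites hand over temporaries directly instead of `std::unique_ptr` arguments; an illustrative invocation with hypothetical values:

    TestBinaryOp(HloOpcode::kAdd,
                 /*expected=*/LiteralUtil::CreateR1<float>({3.f, 5.f}),
                 /*lhs=*/LiteralUtil::CreateR1<float>({1.f, 2.f}),
                 /*rhs=*/LiteralUtil::CreateR1<float>({2.f, 3.f}));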
bool use_bfloat16_;
@@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- Shape shape = low->shape();
+ Shape shape = low.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = LiteralUtil::CreateR0<float>(1.f);
- Shape shape = value->shape();
+ Shape shape = value.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs select
@@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- Shape shape = on_true->shape();
+ Shape shape = on_true.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
auto c2 =
@@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs
@@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+ std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
lhs_instruction, param_rhs2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate(args);
+ Literal result = Evaluate(args);
auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies Reshape operation is correctly evaluated.
@@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
@@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
- result->EachCell<NativeT>(
- [&](absl::Span<const int64> indices, NativeT value) {
- std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
- });
+ result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+ });
}
// Verifies Broadcast operation is correctly evaluated.
@@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction, {1, 2}));
+ output_literal.shape(), literal_instruction, {1, 2}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloInstruction::CreateConstant(std::move(input_literal)));
// The broadcast dimensions should be empty in the case of scalars.
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction,
+ output_literal.shape(), literal_instruction,
/*broadcast_dimensions=*/{}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int64>(
{{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<int64>({100, 200});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
- ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
- ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
PaddingConfig CreatePaddingConfig(
@@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
shape, operand_instruction, padding_value_instruction, padding_config));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
@@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
@@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected_array = Array2D<float>({
@@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = Array2D<float>({
{22.f, 28.f},
@@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest,
@@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
@@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
std::iota(input_elems.begin(), input_elems.end(), -7);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
std::iota(filter_elems.begin(), filter_elems.end(), -31);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
@@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
/*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 1, 8);
expected_array.FillWithYX(
Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
module().AddEntryComputation(b.Build());
HloEvaluator hlo_eval;
- std::unique_ptr<Literal> result =
- hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+ Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
// Reducing many numbers should be fast because it doesn't create
@@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({6, 18});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
@@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
/*strides=*/{2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
shape, operand, update, start_indices));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto result_inner_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
- auto expected = LiteralUtil::MakeTuple({
- result_inner_literal.get(),
- result_inner_literal.get(),
- });
+ auto expected =
+ LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
+ Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, &param0_literal}, {square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
- auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+ auto result =
+ evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1906,12 +1901,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1930,12 +1925,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1954,14 +1949,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>(
+ LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1980,15 +1974,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest,
@@ -2008,15 +2001,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -2035,12 +2027,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -2059,13 +2050,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2084,11 +2074,10 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2108,12 +2097,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> start_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2138,15 +2127,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2171,15 +2158,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2205,15 +2191,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2239,15 +2223,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2273,17 +2255,15 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(
+ LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
- ErrorSpec{0.1, 0.01}));
+ Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2309,15 +2289,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2343,15 +2321,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2376,21 +2353,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, //
{{-40, 40}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest,
@@ -2416,21 +2390,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2455,16 +2426,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2489,17 +2458,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2524,13 +2490,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *operand,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ operand, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2557,16 +2521,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<int32>({10, 61, 32});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs
@@ -2603,11 +2564,29 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+ Literal arg = LiteralUtil::CreateR1<bfloat16>(
{bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+ Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+}
+
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+ // Regression test for b/114735354.
+ const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+ arg = f32[2,2,2]{0,1,2} parameter(0)
+ ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ LayoutUtil::MakeLayout({0, 1, 2}));
+ Literal actual = Evaluate({&arg});
+ EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
}
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 63303aef1e..8fb17a0033 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
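Both handlers drop the layout check on the assumption that `Convert`/`BitcastConvert` already produce a literal whose layout matches `convert->shape()`. If that invariant ever failed, the removed branch translates to the value API as a short fix; a sketch, not part of this diff:

    if (!LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
      result = result.Relayout(convert->shape().layout());
    }
    parent_->evaluated_[convert] = std::move(result);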
@@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = absl::make_unique<Literal>(dot->shape());
+ Literal result(dot->shape());
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
// Loop through each element of the operand, assigning it to the
// corresponding index of the resulting padded literal.
@@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = absl::make_unique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
- results[i] = absl::make_unique<Literal>(result_shape);
+ results[i] = Literal(result_shape);
}
Status eval_status;
@@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
- TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
- absl::InlinedVector<std::unique_ptr<Literal>, 1>
- embedded_operands;
+ absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
- [](const std::unique_ptr<Literal>& ptr) {
- return ptr.get();
- });
+ [](Literal& literal) { return &literal; });
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
- result_values[0] = computed_result->Get<ReturnT>({});
+ result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
- result_values[i] = computed_result->Get<ReturnT>(
+ result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
- auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
- TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
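
Concretely: when index_vector_dim equals the rank of the indices, each scatter index is an implicit one-element vector, so the helper appends a degenerate trailing dimension, e.g. s64[5] with index_vector_dim=1 becomes s64[5,1]. A tiny sketch with an illustrative literal:

// s64[5] indices where index_vector_dim == rank(indices) == 1.
xla::Literal indices =
    xla::LiteralUtil::CreateR1<xla::int64>({0, 2, 4, 1, 3});
// Append a trailing 1 so each index reads as a 1-element vector.
xla::Literal reshaped = indices.Reshape({5, 1}).ConsumeValueOrDie();
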
// Returns a ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
- result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t. the
@@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
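
The clamp above keeps the update window in bounds: with an operand dimension of size 10 and an update of size 4, a requested start of 8 is clamped to 6 (= 10 - 4) and a negative start to 0. The rule in isolation, as a self-contained sketch:

#include <algorithm>
#include <cstdint>

// Clamp a start index so [start, start + update_size) fits in
// [0, operand_size), matching the per-dimension loop above.
int64_t ClampStart(int64_t start, int64_t operand_size, int64_t update_size) {
  return std::min<int64_t>(std::max<int64_t>(0, start),
                           operand_size - update_size);
}
// ClampStart(8, 10, 4) == 6; ClampStart(-3, 10, 4) == 0.
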
@@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
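
All of the element-wise helpers above reduce to the same Populate pattern over value Literals. A minimal standalone sketch of a binary op on illustrative R1 operands:

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Elementwise add of two same-shaped literals via Populate.
xla::Literal lhs = xla::LiteralUtil::CreateR1<float>({1, 2, 3});
xla::Literal rhs = xla::LiteralUtil::CreateR1<float>({10, 20, 30});
xla::Literal out(lhs.shape());
TF_CHECK_OK(out.Populate<float>([&](absl::Span<const xla::int64> idx) {
  return lhs.Get<float>(idx) + rhs.Get<float>(idx);
}));
// out is now f32[3] {11, 22, 33}.
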
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 0345a2a5f8..287ba84b3b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -123,6 +123,10 @@ class NodeFilter {
// We arbitrarily set this as the boundary between "large" and "small"
// instructions.
bool IsSmall(const HloInstruction* instr) {
+ if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
+ ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
+ return true;
+ }
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
}
@@ -465,9 +469,8 @@ stylesheet=<
string graph_label =
StrCat(label_, "<br/>Computation ", computation_->name());
if (computation_->IsFusionComputation()) {
- StrAppend(&graph_label,
- StrCat(" (in fusion instruction ",
- computation_->FusionInstruction()->name(), ")"));
+ StrAppend(&graph_label, " (in fusion instruction ",
+ computation_->FusionInstruction()->name(), ")");
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 25ae344ea5..e905f2983a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
- instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -527,7 +528,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
- std::unique_ptr<Literal> literal) {
+ Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
@@ -2096,7 +2097,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
- if (!control_predecessors_.empty()) {
+ if (options.print_control_dependencies() && !control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
StrJoin(control_predecessors_, ", ",
[&](string* out, HloInstruction* pre) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5581c17c2d..4f6cac1396 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -82,6 +82,7 @@ class HloPrintOptions {
print_operand_shape_(true),
print_program_shape_(true),
print_percent_(true),
+ print_control_dependencies_(true),
canonicalize_instruction_names_(false),
indent_amount_(0),
is_in_nested_computation_(false) {}
@@ -94,7 +95,8 @@ class HloPrintOptions {
.set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
- .set_print_percent(false);
+ .set_print_percent(false)
+ .set_print_control_dependencies(false);
}
// Options to produce the canonical string representing an isomorphic
@@ -108,6 +110,7 @@ class HloPrintOptions {
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
+ .set_print_control_dependencies(false)
.set_canonicalize_instruction_names(true);
}
@@ -153,6 +156,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, control dependencies will be printed.
+ HloPrintOptions& set_print_control_dependencies(bool value) {
+ print_control_dependencies_ = value;
+ return *this;
+ }
+
// If true, only part of the operands will be printed out, and their names
// will be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
@@ -190,6 +199,9 @@ class HloPrintOptions {
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
+ bool print_control_dependencies() const {
+ return print_control_dependencies_;
+ }
bool canonicalize_instruction_names() const {
return canonicalize_instruction_names_;
}
@@ -205,6 +217,7 @@ class HloPrintOptions {
bool print_operand_shape_;
bool print_program_shape_;
bool print_percent_;
+ bool print_control_dependencies_;
bool canonicalize_instruction_names_;
int indent_amount_;
bool is_in_nested_computation_;
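
A short usage sketch of the new option in the class's fluent-setter style; the module pointer is assumed to be in scope:

#include <iostream>

// Suppress control-predecessor annotations, e.g. when diffing module text
// where control edges are noise.
xla::HloPrintOptions options;
options.set_print_control_dependencies(false);
std::cout << module->ToString(options);
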
@@ -346,8 +359,7 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
- static std::unique_ptr<HloInstruction> CreateConstant(
- std::unique_ptr<Literal> literal);
+ static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index fb7345a2ad..e92882c22a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
- : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+ : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- if (literal_ != nullptr) {
+ if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
+ *literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
+ CHECK(literal_.has_value());
+ return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if (literal_ != nullptr &&
+ if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ *proto.mutable_literal() = literal_.ToProto();
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index c3a7801164..2d7bc83855 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
- explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is a literal associated with this instruction.
- bool HasLiteral() const { return literal_ != nullptr; }
+ bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
- string TracingTag() const { return literal_->GetR1U8AsString(); }
+ string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
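
With literal_ now an absl::optional<Literal>, the dropped-literal state is has_value() == false rather than a null pointer, and callers hand the constant a Literal by value. A sketch of the post-change calling convention:

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Constants are created from a Literal value, not a unique_ptr<Literal>.
std::unique_ptr<xla::HloInstruction> constant =
    xla::HloInstruction::CreateConstant(
        xla::LiteralUtil::CreateR0<float>(42.0f));
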
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 9bfb0af96c..c7ec88d450 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <map>
#include <queue>
@@ -582,4 +582,22 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
size_function, nullptr, empty_map);
}
+HloMemoryScheduler::HloMemoryScheduler(
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm)
+ : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, size_function_, algorithm_));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ return true;
+}
+
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+ bool changed = module->has_schedule();
+ module->clear_schedule();
+ return changed;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 54e32340ba..5e02868eba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -86,6 +87,37 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+ // size_function is the function returning the number of bytes required for a
+ // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+ // specified, then DefaultMemoryScheduler is used.
+ HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
+ ~HloMemoryScheduler() override = default;
+ absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ LogicalBuffer::SizeFunction size_function_;
+ MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs, HloModule::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+ HloDescheduler() = default;
+ ~HloDescheduler() override = default;
+ absl::string_view name() const override { return "hlo-descheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
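
A usage sketch of the two passes declared above, using the same BufferValue size function as the tests in this change; module is an HloModule* assumed to be in scope:

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Schedule the module, then clear the schedule again with the descheduler.
xla::HloMemoryScheduler scheduler([](const xla::BufferValue& buffer) {
  return xla::ShapeUtil::ByteSizeOf(buffer.shape());
});
TF_CHECK_OK(scheduler.Run(module).status());    // module->has_schedule() is now true
xla::HloDescheduler descheduler;
TF_CHECK_OK(descheduler.Run(module).status());  // schedule cleared
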
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 6afe51997e..1b9e9bfc77 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <memory>
#include <string>
@@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- HloSchedule schedule,
- ScheduleModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
+ HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ });
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
- schedule.sequence(module->entry_computation()).instructions();
+ module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
EXPECT_EQ(param, sequence.front());
EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(schedule);
+ SequentialHloOrdering ordering(module->schedule());
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+ // Clear the schedule using the descheduling pass.
+ HloDescheduler descheduler;
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+ descheduler.Run(module.get()));
+ EXPECT_TRUE(descheduler_changed);
+ EXPECT_FALSE(module->has_schedule());
}
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+ // Set the unique id for this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+ // for computations and instructions created later. Also, set
+ // next_unique_id_ to one greater than the max unique id of any
+ // instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
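
For example, if a deserialized computation arrives with instruction ids {3, 7, 12}, the non-uniquifying branch above must leave next_unique_id_ at 13 or more so later instructions cannot collide. The invariant in isolation, as a hypothetical helper:

#include <algorithm>
#include <cstdint>
#include <vector>

// After absorbing pre-assigned ids, the next id must exceed all of them.
int64_t NextIdAfterAbsorbing(const std::vector<int64_t>& existing_ids,
                             int64_t next_unique_id) {
  for (int64_t id : existing_ids) {
    next_unique_id = std::max(next_unique_id, id + 1);
  }
  return next_unique_id;
}
// NextIdAfterAbsorbing({3, 7, 12}, /*next_unique_id=*/0) == 13.
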
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+ // Because we didn't uniquify the names or the ids, double-check that the
+ // instruction and computation names and ids from the proto are unique.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 26fd1b2438..3bc2d13781 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -253,7 +253,7 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names);
+ bool uniquify_identifiers);
const string name_;
HloModuleConfig config_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index 98d20315e3..f7be5cae22 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -36,23 +36,6 @@ namespace xla {
namespace {
-bool HasSendRecv(HloComputation* computation) {
- for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kSendDone ||
- instruction->opcode() == HloOpcode::kRecv ||
- instruction->opcode() == HloOpcode::kRecvDone) {
- return true;
- }
- for (auto* sub_computation : instruction->called_computations()) {
- if (HasSendRecv(sub_computation)) {
- return true;
- }
- }
- }
- return false;
-}
-
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
for (auto* computation : module->computations()) {
@@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
while_body_root->opcode() != HloOpcode::kTuple ||
- HasSendRecv(while_body_comp)) {
+ while_body_comp->HasSideEffect() ||
+ xla_while->while_condition()->HasSideEffect()) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
- // with no send/recv instructions.
+ // with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e490..bf66cc6bc3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
+// Tests that if a loop variable is not referenced outside of a kWhile, its
+// updates within the loop body are still not elided when the condition
+// computation uses it.
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
+ auto module = ParseHloString(R"(
+ HloModule InfiniteLoop
+ WhileBody {
+ body_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
+ }
+ WhileCondition {
+ cond_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ p0 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
+ while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
new file mode 100644
index 0000000000..f9b56ef464
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+namespace xla {
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ std::unique_ptr<HloModule> module)
+ : name_(name) {
+ push_back(std::move(module));
+}
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules)
+ : name_(name) {
+ for (auto& module : modules) {
+ push_back(std::move(module));
+ }
+}
+
+std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() {
+ std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_);
+
+ // Clear everything so the object state is in a known (empty) state.
+ modules_.clear();
+ module_ptrs_.clear();
+ return ret_modules;
+}
+
+string HloModuleGroup::ToString() const {
+ std::ostringstream s;
+ s << "HloModuleGroup " << name() << "\n\n";
+ for (const HloModule* module : modules()) {
+ s << module->ToString() << "\n";
+ }
+ return s.str();
+}
+
+HloModuleGroupProto HloModuleGroup::ToProto() const {
+ HloModuleGroupProto proto;
+ proto.set_name(name());
+ for (const HloModule* module : modules()) {
+ *proto.add_hlo_modules() = module->ToProto();
+ }
+ return proto;
+}
+
+/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs) {
+ TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty";
+ TF_RET_CHECK(proto.hlo_modules_size() > 0)
+ << "Module group must have at least one HLO module";
+ TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size());
+
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (int i = 0; i < proto.hlo_modules_size(); ++i) {
+ const HloModuleProto& module_proto = proto.hlo_modules(i);
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(module_proto, module_configs[i]));
+ modules.push_back(std::move(module));
+ }
+
+ return HloModuleGroup(proto.name(), absl::MakeSpan(modules));
+}
+
+void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
+ modules_.push_back(std::move(module));
+ module_ptrs_.push_back(modules_.back().get());
+}
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) {
+ out << group.ToString();
+ return out;
+}
+
+} // namespace xla
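
A usage sketch of the group API defined above, round-tripping a group through its proto form; the module variable is assumed to hold a parsed std::unique_ptr<HloModule>:

#include "tensorflow/compiler/xla/service/hlo_module_group.h"

// Build a single-module group and reconstruct it from the serialized proto.
xla::HloModuleGroup group("example_group", std::move(module));
xla::HloModuleGroupProto proto = group.ToProto();
xla::StatusOr<xla::HloModuleGroup> restored =
    xla::HloModuleGroup::CreateFromProto(proto, {group.module(0).config()});
TF_CHECK_OK(restored.status());
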
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
new file mode 100644
index 0000000000..7338be8b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// An abstraction representing an ordered set of HLO modules built to run
+// concurrently across different devices.
+class HloModuleGroup {
+ public:
+ // Construct an empty module group.
+ explicit HloModuleGroup(absl::string_view name) : name_(name) {}
+
+ // Construct a module group containing a single module.
+ HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+
+ // Construct a module group containing any number of modules.
+ HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules);
+
+ // Returns the modules contained in the group.
+ const std::vector<HloModule*>& modules() const { return module_ptrs_; }
+
+ // Returns a module at a particular index.
+ HloModule& module(int index) const { return *module_ptrs_.at(index); }
+
+ // Add a module to the back of the vector of modules in the group.
+ void push_back(std::unique_ptr<HloModule> module);
+
+ // Moves all modules from the group into the returned vector. After this
+ // method runs, the module group will be empty.
+ std::vector<std::unique_ptr<HloModule>> ConsumeModules();
+
+ string name() const { return name_; }
+ string ToString() const;
+
+ // Serialize the module group to/from a proto.
+ HloModuleGroupProto ToProto() const;
+ static StatusOr<HloModuleGroup> CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs);
+
+ private:
+ string name_;
+
+ // Vector of modules as std::unique_ptrs.
+ std::vector<std::unique_ptr<HloModule>> modules_;
+
+ // Vector of modules as normal pointers. This vector is kept in sync with
+ // modules_ as modules are added to the group with push_back.
+ std::vector<HloModule*> module_ptrs_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
new file mode 100644
index 0000000000..ebf790ba6f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class HloModuleGroupTest : public HloTestBase {
+ protected:
+ HloModuleGroupTest() = default;
+};
+
+TEST_F(HloModuleGroupTest, SingleModule) {
+ const string text = R"(
+HloModule simple_module
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ HloModuleGroup group(TestName(), std::move(module));
+
+ EXPECT_EQ(group.modules().size(), 1);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 1);
+ EXPECT_THAT(
+ group_copy.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
+ EXPECT_EQ(modules.size(), 1);
+ EXPECT_EQ(group.modules().size(), 0);
+}
+
+TEST_F(HloModuleGroupTest, MultipleModules) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ std::vector<std::unique_ptr<HloModule>> modules;
+ modules.push_back(std::move(module_0));
+ modules.push_back(std::move(module_1));
+ HloModuleGroup group(TestName(), absl::MakeSpan(modules));
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config(),
+ group.module(1).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 2);
+}
+
+TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ HloModuleGroup group(TestName());
+ group.push_back(std::move(module_0));
+ group.push_back(std::move(module_1));
+
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 400bd4d947..39f38b417a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+ // Verify that serializing then deserializing an HLO proto preserves the
+ // unique IDs of the computations and instructions.
+ const string text =
+ R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+ input = f32[8,16,256]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+
+ // Perform various transformations on the graph:
+ //
+ // * clone the reduction function
+ // * replace use of reduction function with the clone.
+ // * add a random instruction to the entry computation.
+ //
+ // This will create instruction and computation IDs which are interesting:
+ // not consecutive and not densely packed.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+ HloComputation* reduction = root->to_apply();
+ HloComputation* reduction_clone =
+ module->AddEmbeddedComputation(reduction->Clone());
+ root->set_to_apply(reduction_clone);
+ TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+ HloInstruction* negate = entry->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+ entry->set_root_instruction(negate);
+
+ // Schedule the transformed module; this verifies that the serialized schedule
+ // is robust against non-consecutive IDs as well (b/114712358).
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ HloMemoryScheduler scheduler(size_fn);
+ TF_ASSERT_OK(scheduler.Run(module.get()).status());
+ ASSERT_TRUE(module->has_schedule());
+
+ // Serialize and deserialize, then verify that the instruction and computation
+ // unique ids are the same.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+ // The module IDs should *not* be the same because module ids must be globally
+ // unique.
+ EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+ // Verify that the computations and instructions all have the same unique id.
+ auto computation_copy_it = module_copy->computations().begin();
+ for (const HloComputation* computation_orig : module->computations()) {
+ const HloComputation* computation_copy = *computation_copy_it++;
+ EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original computation %s != ID of deserialized "
+ "computation %s: %d != %d",
+ computation_orig->name(), computation_copy->name(),
+ computation_orig->unique_id(), computation_copy->unique_id());
+
+ auto instruction_copy_it = computation_copy->instructions().begin();
+ for (const HloInstruction* instruction_orig :
+ computation_orig->instructions()) {
+ const HloInstruction* instruction_copy = *instruction_copy_it++;
+ EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original instruction %s != ID of deserialized "
+ "instruction %s: %d != %d",
+ instruction_orig->name(), instruction_copy->name(),
+ instruction_orig->unique_id(), instruction_copy->unique_id());
+ }
+ }
+
+  // Verify that the next unique ID which the module would hand out is greater
+  // than the unique ID of any instruction.
+ int next_id = module_copy->NewUniqueInstructionId();
+ for (const HloComputation* computation : module_copy->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_GT(next_id, instruction->unique_id());
+ }
+ }
+}
+
} // namespace
} // namespace xla
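
The round trip exercised above is the same one production callers rely on. A
minimal sketch of it, assuming an already-built HloModule* module whose IDs
should survive serialization:

    HloModuleProto proto = module->ToProto();
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModule> restored,
        HloModule::CreateFromProto(proto, module->config()));
    // Computation and instruction unique IDs now match the originals; only
    // restored->unique_id() differs, since module IDs are globally unique.
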
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 6b6005e7a5..00970bcda3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c54360b063..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = absl::make_unique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}
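
The signature change ripples out to every caller: literals are now filled in
place instead of being allocated behind a unique_ptr. A sketch of the new
call-site shape (the surrounding builder code is hypothetical; the parser entry
points are the ones declared above):

    Literal literal;
    if (!ParseLiteral(&literal, shape)) {
      return false;  // the parser has already recorded a TokenError
    }
    // The caller owns the value; move it into the constant instruction.
    HloInstruction* constant = builder->AddInstruction(
        HloInstruction::CreateConstant(std::move(literal)));
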
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c95972b..d9848cee0b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0a0a6a323e..bd6dd79b67 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -27,15 +27,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -1194,51 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(HloModule* module,
- HloSchedule* schedule,
- int64 memory_limit_bytes,
- RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The schedule is constructed entirely by this method.
- TF_RET_CHECK(schedule->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
VLOG(1) << "HloRematerialization() with memory limit of "
- << HumanReadableNumBytes(memory_limit_bytes);
+ << HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial schedule of HLO instructions.
- TF_ASSIGN_OR_RETURN(*schedule,
- ScheduleModule(*module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(*schedule);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
-
- // RemoveUnnecessaryCopies only considers interference when determining
- // whether it is legal to remove a copy. However, copies in the graph may be
- // necessary for other reason such as preventing a constant from being live
- // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
- // TODO(b/80249101): Break copy insertion into several passes and run each
- // one once in the regular HLO pipeline.
- TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
-
- // The passes above can add and remove copies, update the schedule to
- // account for these transformations. Newly added instructions will be
- // placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(schedule->Update());
-
- TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(*schedule), module));
- }
-
+ TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1254,7 +1214,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
});
const int64 adjusted_memory_limit_bytes =
- memory_limit_bytes - module_output_size;
+ memory_limit_bytes_ - module_output_size;
VLOG(1) << "Adjusted memory limit accounting for output ("
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
@@ -1263,13 +1223,14 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, schedule](const CallGraphNode& node) -> Status {
+ [this, module](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(
- node.computation(),
- schedule->sequence(node.computation()).instructions()));
+ ComputePeakMemory(node.computation(),
+ module->schedule()
+ .sequence(node.computation())
+ .instructions()));
}
return Status::OK();
},
@@ -1287,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), schedule,
- adjusted_memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(
+ bool changed,
+ RematerializeComputation(module->entry_computation(), &module->schedule(),
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1298,7 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// After DCE, the module sequence may include instructions which no longer
// exist.
- TF_RETURN_IF_ERROR(schedule->Update());
+ TF_RETURN_IF_ERROR(module->schedule().Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1315,32 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
- if (sizes != nullptr) {
- sizes->before_bytes = before_peak_memory;
- sizes->after_bytes = current_peak_memory;
+ if (sizes_ != nullptr) {
+ sizes_->before_bytes = before_peak_memory;
+ sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
- if (current_peak_memory > memory_limit_bytes) {
+ if (current_peak_memory > memory_limit_bytes_) {
LOG(WARNING) << absl::StrFormat(
"Can't reduce memory use below %s (%d bytes) by rematerialization; "
"only reduced to %s (%d bytes)",
- HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+ HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
- const HloRematerialization::ShapeSizeFunction& size_function,
- int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
- copy_insertion);
-}
-
} // namespace xla
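
With the static RematerializeAndSchedule entry point gone, scheduling is now
the caller's responsibility. A sketch of the new contract, with an arbitrary
example memory limit (Run() now TF_RET_CHECKs module->has_schedule()):

    HloMemoryScheduler scheduler([](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape());
    });
    TF_RETURN_IF_ERROR(scheduler.Run(module).status());

    HloRematerialization::RematerializationSizes sizes;
    HloRematerialization remat(
        [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); },
        /*memory_limit_bytes=*/16 * 1024, &sizes);
    TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
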
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index fa0414b472..e2aaf18b3e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,17 +17,23 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
namespace xla {
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -38,10 +44,7 @@ class HloRematerialization {
int64 after_bytes;
};
- // Rematerialize HLO instructions in the given module to reduce peak memory
- // use below memory_limit_bytes where memory use is defined as the total size
- // of all live HLO instruction values. Parameters and constants are included
- // in memory use estimates. Method parameters:
+ // Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -49,51 +52,27 @@ class HloRematerialization {
// memory_limit_bytes: The threshold number of bytes to reduce memory use to
// via rematerialization.
//
- // hlo_module: HLO module to rematerialize instructions in.
- //
- // schedule: Should point to an empty HloSchedule. Upon return
- // contains the HLO instruction order which was used for
- // rematerialization. This is the order in which HLO instructions should
- // be emitted to minimize memory use.
- //
- // sizes: Optional outparam that indicates the peak memory usage of the HLO
- // module before/after rematerialization.
- //
- // copy_insertion: If non-null, run copy elision after scheduling. This
- // pass is used to eliminate copies that were inserted by copy insertion
- // before HLO scheduling.
- //
- // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
- // insertion is integrated with HLO scheduling.
- //
- // Returns whether any instructions were rematerialized. If memory use is
- // already below the given limit then no instructions are rematerialized and
- // false is returned.
- //
- // CSE will undo the effects of this optimization and should not be run after
- // this pass. In general, this pass should be run very late immediately before
- // code generation.
- static StatusOr<bool> RematerializeAndSchedule(
- const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- HloSchedule* schedule, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion = nullptr);
-
- protected:
- HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function)
- : scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function) {}
+ // sizes: Pointer to data structure which records the peak memory usage of
+  //   the HLO module before/after rematerialization. Values are set during
+ // Run(). Can be nullptr.
+ HloRematerialization(const ShapeSizeFunction& size_function,
+ int64 memory_limit_bytes, RematerializationSizes* sizes)
+ : size_function_(size_function),
+ memory_limit_bytes_(memory_limit_bytes),
+ sizes_(sizes) {}
~HloRematerialization() {}
+ absl::string_view name() const override { return "rematerialization"; }
+
// Runs rematerialization on the given module. Returns whether the module was
- // changed. memory_limit is the target maximum peak memory usage by the
- // module. schedule should be an empty HloSchedule. Upon return sequence
- // contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
- int64 memory_limit, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion);
+  // changed. Requires that the module has a schedule set
+  // (HloModule::has_schedule() is true) before running. If memory use is
+  // already below the limit specified in the constructor, no instructions are
+  // rematerialized and false is returned.
+ StatusOr<bool> Run(HloModule* module) override;
+ protected:
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
@@ -121,6 +100,14 @@ class HloRematerialization {
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
+ // The threshold number of bytes to reduce memory use to via
+ // rematerialization.
+ const int64 memory_limit_bytes_;
+
+ // Pointer to data structure which records the peak memory usage of the HLO
+ // module before/after rematerialization
+ RematerializationSizes* sizes_;
+
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
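
Because the pass now implements HloPassInterface, it can also sit directly in a
pass pipeline after the scheduler rather than being driven by hand. A
hypothetical wiring (HloPassPipeline is assumed from the surrounding service
code, not part of this patch, and the memory limit is an arbitrary example):

    auto shape_size = [](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape);
    };
    HloPassPipeline pipeline("late-passes");
    pipeline.AddPass<HloMemoryScheduler>([=](const BufferValue& buffer) {
      return shape_size(buffer.shape());
    });
    pipeline.AddPass<HloRematerialization>(
        shape_size, /*memory_limit_bytes=*/16 * 1024, /*sizes=*/nullptr);
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());
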
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83cb113bfb..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
- HloModule* module,
- HloSchedule* schedule) {
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- schedule, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The rematerialized broadcast should be immediate before the concat in the
// sequence.
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &schedule));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 66ac1f66fd..fa7f216321 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals) {
+ const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
- literal_pointers.push_back(literal.get());
+ literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
- argument_pointers.push_back(argument.get());
+ argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = absl::make_unique<Literal>();
+ Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, options.outfeed_shape, literal.get()));
+ executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
- std::vector<std::unique_ptr<Literal>> exec_results;
+ std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 76d8b92bed..2e934bf66a 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
- std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+ std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals);
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- const ShapedBuffer& buffer);
+ const absl::Span<const Literal> literals);
+ StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const Literal* const> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal* const> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@@ -140,7 +139,7 @@ class HloRunner {
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
- StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
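
On the client side, the value-returning API removes a layer of indirection at
every call site. A sketch, assuming a configured HloRunner `runner`, a parsed
module, and hypothetical argument literals arg0 and arg1:

    std::vector<const Literal*> args = {&arg0, &arg1};
    TF_ASSIGN_OR_RETURN(Literal result,
                        runner.Execute(std::move(module), args));
    EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
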
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
index eb52582bb5..1424569ac1 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a1f2..6fd734a2b9 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 069586a738..50f39cbcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1123,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+ // If the module has a schedule, it must be valid.
+ if (module->has_schedule()) {
+ TF_RETURN_IF_ERROR(module->schedule().Verify());
+ }
+
return false;
}
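
The practical consequence is that a pass which mutates a scheduled module must
keep the schedule consistent before the verifier runs, following the pattern
used elsewhere in this patch:

    // After adding or removing instructions in a scheduled module:
    TF_RETURN_IF_ERROR(module->schedule().Update());
    // The verifier's schedule().Verify() call will now fail on any computation
    // whose sequence no longer matches its instructions.
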
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0cac210c24..8f0423bb1c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));
auto module = CreateNewModule();
@@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 37b774b8a5..06f0e1ed25 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> inner_broadcast_result,
+ Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 9746d176cc..df9cbab915 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
- Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
- StatusOr<Literal*> TakeOwnership(
- StatusOr<std::unique_ptr<Literal>> literal_or_error) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- std::move(literal_or_error));
+ StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
- std::vector<std::unique_ptr<Literal>> owned_literals_;
+ std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
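
One subtlety of storing Literals by value: TakeOwnership returns
&owned_literals_.back(), which is only guaranteed valid until the next
push_back reallocates the vector. If callers retain the returned pointers while
more literals are added, a container with stable element addresses avoids the
hazard; a sketch of that alternative (not part of the patch):

    std::deque<Literal> owned_literals_;  // addresses stable across push_back

    Literal* TakeOwnership(Literal literal) {
      owned_literals_.push_back(std::move(literal));
      return &owned_literals_.back();
    }
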
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 5695bc2420..7e967f035c 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloTestBase;
+using InlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
@@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907eae0c..3fdc2cee9a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible(
return do_not_duplicate;
}
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create a vector of instructions in post order and a map from
+// HloInstruction* to each instruction's index in the vector. An instruction is
+// "removed" from the vector by setting its element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+ explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+ post_order_ = computation->MakeInstructionPostOrder();
+
+ for (size_t i = 0; i < post_order_.size(); ++i) {
+ InsertOrDie(&post_order_index_, post_order_[i], i);
+ }
+ }
+
+ std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ // Instructions are "removed" from the post order by nulling out the element
+ // in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ while (!post_order_.empty() && post_order_.back() == nullptr) {
+ post_order_.pop_back();
+ }
+ if (post_order_.empty()) {
+ return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+ }
+ // We want to iterate in reverse post order, so remove from the back of the
+ // vector.
+ HloInstruction* instruction = post_order_.back();
+ post_order_.pop_back();
+
+ CHECK(instruction != nullptr);
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index_.erase(instruction);
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+ // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+    // If we are considering the operands of C for fusion into C, we might
+    // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+    //              A" = ...
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the order
+ // they appear int the queue. In the example, this ensures that B will be
+ // considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the queue.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher priority in the queue come first.
+ return (
+ FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+ });
+ return std::make_pair(instruction, sorted_operand_numbers);
+ }
+
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) override {
+ // Fusing an instruction into a fusion instruction can change the operand
+    // set of the fusion instruction. For simplicity, just re-enqueue the
+    // fusion instruction and reconsider it for further fusion in the next
+    // iteration.
+ InsertOrDie(&post_order_index_, fusion, post_order_.size());
+ post_order_.push_back(fusion);
+ }
+
+ void RemoveInstruction(HloInstruction* instruction) override {
+ post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+ post_order_index_.erase(instruction);
+ }
+
+ private:
+ std::vector<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+} // namespace
+
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer) {
+ return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(2) << "Before instruction fusion:";
XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
computation_ = computation;
reachability_ = computation_->ComputeReachability();
- // We want to be able to remove arbitrary instructions from the post order
- // and also compare positions of instructions in the post order. To make
- // this possible, create vector of instructions in post order and create a
- // map from HloInstruction* to the instruction's index in the vector. An
- // instruction is "removed" from the vector by setting it's element to
- // nullptr.
- std::vector<HloInstruction*> post_order =
- computation_->MakeInstructionPostOrder();
-
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
- for (size_t i = 0; i < post_order.size(); ++i) {
- InsertOrDie(&post_order_index, post_order[i], i);
- }
-
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+ HloInstructionSet do_not_duplicate =
+ ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
+ auto fusion_queue =
+ GetFusionQueue(computation_, [&](HloInstruction* producer) {
+ return do_not_duplicate.count(producer) > 0;
+ });
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
// edges. When we fuse an edge, we create a copy of the producer inside the
// fusion instruction.
- while (!post_order.empty()) {
- // We want to iterate in reverse post order, so remove from the back of
- // the vector.
- HloInstruction* instruction = post_order.back();
- post_order.pop_back();
-
- // Instructions are "removed" from the post order by nulling out the
- // element in the vector, so if the pointer is null, continue to the next
- // instruction in the sort.
+ while (true) {
+ auto next_entry =
+ fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+ auto instruction = next_entry.first;
if (instruction == nullptr) {
- continue;
+ break;
}
- // Remove instruction from the index map to ensure the vector and map stay
- // consistent.
- post_order_index.erase(instruction);
-
if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- // Consider each operand of this instruction for fusion into this
- // instruction. We want to consider the operands in a particular order to
- // avoid creating duplicate instruction clones in the fusion instruction.
- // For example, consider the following expression:
- //
- // A = ...
- // B = op(A)
- // C = op(A, B)
- //
- // If we are considering the operands of C for fusion into C. We might
- // fuse A or B first. If we fuse A first, we get:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // C' = op(A', B) }
- //
- // Where A' and C' are clones of A and C, respectively. Now only B is an
- // operand of the fusion instruction C_fusion, so then we fuse B:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // B' = op(A)
- // C' = op(A', B') }
- //
- // Now A is an operand of C_fusion again, so we then fuse A (again!):
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // A" = ..
- // B' = op(A")
- // C' = op(A', B') }
- //
- // We prevent this duplication by considering the operands in the reverse
- // order they appear in the instruction post order. In the example, this
- // ensures that B will be considered before A.
- //
- // We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers;
- sorted_operand_numbers.reserve(instruction->operands().size());
- for (int i = 0; i < instruction->operands().size(); ++i) {
- // This will happen if we have two possible instructions to fuse the
- // same operand into; once the operand is fused into one instruction,
- // the other instruction will get a new get-tuple-element as its
- // operand, which is not in the post-order index.
- // TODO(tjoerg): Look into fusing past these multi-output fuse points.
- if (post_order_index.find(instruction->mutable_operand(i)) ==
- post_order_index.end()) {
- continue;
- }
- sorted_operand_numbers.push_back(i);
- }
- std::sort(
- sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
- [&](int64 i, int64 j) {
- // Instructions with higher indices in the post order come
- // first.
- return (
- FindOrDie(post_order_index, instruction->mutable_operand(i)) >
- FindOrDie(post_order_index, instruction->mutable_operand(j)));
- });
+ std::vector<int64>& sorted_operand_numbers = next_entry.second;
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// TODO(tjoerg): Consider making multi-output fusion the default.
if (ShouldFuse(instruction, i) &&
do_not_duplicate.count(operand) == 0) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = Fuse(operand, instruction);
} else if (ShouldFuseIntoMultiOutput(instruction, i) &&
!MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = FuseIntoMultiOutput(operand, instruction);
} else {
continue;
}
- // Fusing an instruction into a fusion instruction can change the
- // operand set of the fusion instruction. For simplicity just push the
- // instruction to the top of the post_order and reconsider it for
- // further fusion in the next iteration of the outer loop.
- post_order.push_back(fusion_instruction);
- InsertOrDie(&post_order_index, fusion_instruction,
- post_order.size() - 1);
+ fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+ instruction);
changed = true;
if (operand->user_count() == 0) {
- // Operand is now dead. Remove from post order by setting its
- // location to nullptr.
- post_order[FindOrDie(post_order_index, operand)] = nullptr;
- post_order_index.erase(operand);
-
+ do_not_duplicate.erase(operand);
+ // Operand is now dead. Remove from queue.
+ fusion_queue->RemoveInstruction(operand);
// Remove from computation.
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
}
+
+ if (fusion_instruction != instruction) {
+ do_not_duplicate.erase(instruction);
+ }
break;
}
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 00b658959a..c1fde8ecfc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,6 +24,33 @@ limitations under the License.
namespace xla {
+// A queue interface that allows implementations to choose fusion candidates in
+// a custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers,
+ // given as the consumer's operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify it of the removal
+ // of an instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
@@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface {
static bool IsExpensive(const HloInstruction& instruction);
protected:
+ // Returns a FusionQueue that implements a custom order in which instructions
+ // are fused. The default implementation processes consumers in reverse post
+ // order.
+ virtual std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer);
+
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
// Derived classes should define this method to specify which instructions
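For illustration, a minimal FusionQueue subclass can be sketched against the interface added above. The FifoFusionQueue name, the FIFO policy, and the helper includes are hypothetical and not part of this change; the nullptr-consumer convention for signalling exhaustion matches the `break` added to InstructionFusion::Run above.

#include <deque>
#include <numeric>
#include <unordered_set>

// Hypothetical sketch: dequeues consumers in reverse post order and offers
// all of their operands for fusion, in operand order.
class FifoFusionQueue : public FusionQueue {
 public:
  explicit FifoFusionQueue(const std::vector<HloInstruction*>& post_order)
      : worklist_(post_order.rbegin(), post_order.rend()) {}

  std::pair<HloInstruction*, std::vector<int64>>
  DequeueNextInstructionAndOperandsToFuseInOrder() override {
    // Skip instructions that were removed after being enqueued.
    while (!worklist_.empty() && removed_.count(worklist_.front()) > 0) {
      worklist_.pop_front();
    }
    if (worklist_.empty()) {
      return {nullptr, {}};  // nullptr tells the caller to stop.
    }
    HloInstruction* consumer = worklist_.front();
    worklist_.pop_front();
    std::vector<int64> operands(consumer->operand_count());
    std::iota(operands.begin(), operands.end(), 0);
    return {consumer, operands};
  }

  void RemoveInstruction(HloInstruction* instruction) override {
    removed_.insert(instruction);
  }

 private:
  std::deque<HloInstruction*> worklist_;
  std::unordered_set<HloInstruction*> removed_;
};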
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 5dea124768..a06d6113e8 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
- std::unique_ptr<Literal> result_literal;
+ Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
- TF_ASSIGN_OR_RETURN(result_literal,
- evaluator_->Evaluate<std::unique_ptr<Literal>>(
- *computation, arg_literals));
+ TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+ *computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
- result_literal->shape(), run_options->allocator(),
+ result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- run_options->stream(), *result_literal, result));
+ run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
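The hunk above is part of a wider migration from StatusOr<std::unique_ptr<Literal>> to StatusOr<Literal> with value semantics. A standalone sketch of the pattern, with MakeLiteral as a hypothetical producer:

// Hypothetical producer returning a Literal by value.
StatusOr<Literal> MakeLiteral() {
  Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  return std::move(literal);  // Literal is move-only; moved into the StatusOr.
}

Status UseLiteralByValue() {
  TF_ASSIGN_OR_RETURN(Literal literal, MakeLiteral());
  // Value semantics: member access (literal.shape()) replaces the former
  // pointer dereference (literal->shape()).
  const Shape& shape = literal.shape();
  (void)shape;
  return Status::OK();
}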
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 69c7e42601..752a61476d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout,
@@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(module.get())
+ .Run(module)
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_4, "param"));
auto broadcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ HloInstruction::CreateBroadcast(f32_34, param, {1}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
auto broadcast2 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
auto module = CreateNewModule();
@@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
- module =
+ std::unique_ptr<HloModule> compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
EXPECT_EQ(Status::OK(), backend()
.compiler()
- ->RunBackend(std::move(module),
+ ->RunBackend(std::move(compiled_module),
backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.status());
@@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
param_shape));
computation_layout.mutable_result_layout()->ResetLayout(
LayoutUtil::MakeLayout({2, 1, 0}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(&module(), &computation_layout);
- EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(1)
.layout()
@@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
HloComputation* computation = module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
const HloInstruction* true_root = true_computation->root_instruction();
const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(module.get()).status();
+ Status error_status = layout_assignment.Run(module).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
@@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
TF_ASSERT_OK(
@@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
LayoutUtil::MakeLayout({1, 0}));
ChannelLayoutConstraints channel_constraints;
- AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+ AssignLayouts(&module(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::GetSubshape(
- FindInstruction(module.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
@@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest,
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
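These test updates follow from switching the fixture base class: HloVerifiedTestBase owns the module parsed by ParseAndVerifyModule() and hands out a reference via module(), so the former module.get() calls become plain pointers or references. A hedged sketch of the fixture pattern, with MyPass and the HLO text as hypothetical stand-ins:

class MyPassTest : public HloVerifiedTestBase {};

TEST_F(MyPassTest, RunsOnVerifiedModule) {
  const char* hlo_text = R"(
HloModule m

ENTRY e {
  ROOT p = f32[2,2]{1,0} parameter(0)
}
)";
  ParseAndVerifyModule(hlo_text);  // Parses, verifies, and owns the module.
  MyPass pass;
  // module() returns HloModule&; take its address where a pointer is needed.
  EXPECT_IS_OK(pass.Run(&module()).status());
}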
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index f0e2566a3f..b27a92f2a0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
- *module->add_arguments() = literal->ToProto();
+ *module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
- *module->mutable_result() = literal->ToProto();
+ *module->mutable_result() = literal.ToProto();
return Status::OK();
}
@@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> result_literal,
+ Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
- if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
- result_literal->shape())) {
- *result->mutable_literal() = result_literal->ToProto();
+ if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+ *result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
- result_literal->Relayout(*return_shape)->ToProto();
+ result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- stream.get(), *literal, shaped_buffer));
+ stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, *literal);
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+ literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), *literal));
- *result->mutable_literal() = literal->ToProto();
+ executor, arg->shape_with_layout(), literal));
+ *result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, /*arg_literals=*/{}));
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+ *module, /*arg_literals=*/{}));
// Since the result layout has no effect on the Evaluator's results, relayout
// explicitly here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
+ result_literal = result_literal.Relayout(arg->output_layout());
}
- *result->mutable_literal() = result_literal->ToProto();
+ *result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}
@@ -1162,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
return replicas;
}
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
const string xla_dump_unoptimized_hlo_proto_to =
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1170,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const {
}
HloProto proto = MakeHloProto(module);
return protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+ proto, xla_dump_unoptimized_hlo_proto_to,
+ StrCat(module.name(), ".unoptimized"));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248b15..1f62fad4c8 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@ class Service : public ServiceInterface {
StatusOr<std::vector<se::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
- Status MaybeDumpHloModule(const HloModule& module) const;
+ // Dumps the given (unoptimized) module if the corresponding DebugOptions
+ // field has been set.
+ Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
deleted file mode 100644
index dd53c7531b..0000000000
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/source_map_util.h"
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-namespace source_map_util {
-namespace {
-
-Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
- const char* format, va_list args) {
- string message;
- tensorflow::strings::Appendv(&message, format, args);
- if (!op_metadata.source_file().empty()) {
- absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
- op_metadata.source_line());
- }
- return InvalidArgument("%s", message);
-}
-
-} // namespace
-
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- Status result = InvalidParameterArgumentV(op_metadata, format, args);
- va_end(args);
- return result;
-}
-
-Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- if (executable != nullptr && executable->has_module()) {
- const HloModule& module = executable->module();
- const HloComputation& computation = *module.entry_computation();
- HloInstruction* param = computation.parameter_instruction(parameter_number);
- const OpMetadata& metadata = param->metadata();
- Status result = InvalidParameterArgumentV(metadata, format, args);
- va_end(args);
- return result;
- }
- Status result = InvalidArgumentV(format, args);
- va_end(args);
- return result;
-}
-
-} // namespace source_map_util
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index b8d2d546e5..a21e586efa 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
@@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(
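Both synchronous wrappers above share the substream pattern described in their comments: enqueue the asynchronous transfer on a substream and block on that, so a caller already executing on `stream` (for example from a host callback) cannot deadlock waiting on itself. A reduced sketch, with EnqueueAsyncTransfer as a hypothetical stand-in for the asynchronous call:

Status SynchronousViaSubstream(se::Stream* stream) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);  // Order after work already on `stream`.
  TF_RETURN_IF_ERROR(EnqueueAsyncTransfer(substream));  // Hypothetical call.
  Status block_status = substream->BlockHostUntilDone();
  stream->ReturnSubStream(substream);  // Recycle the substream.
  return block_status;
}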
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 21725946b3..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
- virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source);
+ StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+ const Shape& shape,
+ const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2b2a2eb42a..e9a07b14ed 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})};
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b693872d..516754e211 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
EXPECT_THAT(computation->root_instruction(), gte);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param1);
}
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
EXPECT_THAT(computation->root_instruction(),
op::Negate(op::GetTupleElement(op::Tuple())));
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
EXPECT_THAT(computation->root_instruction(), element);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), tuple_param);
}
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
EXPECT_THAT(computation->root_instruction(), tuple);
}
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
entry = module->AddEntryComputation(builder.Build());
}
- Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+ Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603c7e..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();
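In outline, the trip-count computation above evaluates the induction variable's initial value once, then alternates between evaluating the while condition and the induction-variable update with the current value substituted in, stopping when the condition yields false or the search bound is exceeded. A simplified sketch that merges the separate condition/body induction-variable handles into one `indvar` and elides error handling:

HloEvaluator evaluator(/*max_loop_iterations=*/0);
Literal value = evaluator.Evaluate(indvar_init).ValueOrDie();
for (int64 trip_count = 0; trip_count <= max_value_returned; ++trip_count) {
  Literal cond =
      evaluator.EvaluateWithSubstitutions(cond_root, {{indvar, &value}})
          .ValueOrDie();
  if (cond.data<bool>() == absl::Span<const bool>{false}) {
    return trip_count;  // Condition is false after trip_count iterations.
  }
  value =
      evaluator.EvaluateWithSubstitutions(indvar_update, {{indvar, &value}})
          .ValueOrDie();
}
return nullopt;  // No static trip count within the search bound.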
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 52c895e8d4..df610102b4 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -224,14 +224,13 @@ class ShapeTree {
// REQUIRES: index must exist in the ShapeTree.
iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
- return iterator(&nodes_, typename std::vector<Node>::iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.begin() + (element - &nodes_[0]);
+ return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
- return iterator(&nodes_,
- typename std::vector<Node>::const_iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
+ return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
// Returns the number of leaf nodes in the tree.
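The motivation for the change above is that constructing a std::vector iterator directly from a raw Node* only compiles under some standard-library implementations; computing the element's offset and adding it to begin() is portable. A standalone illustration of the idiom (v and p are hypothetical; p must point into v's storage):

#include <vector>

std::vector<int>::iterator IteratorFor(std::vector<int>& v, int* p) {
  // (p - &v[0]) is the element's index within the vector's contiguous
  // storage; adding it to begin() yields a proper library iterator.
  return v.begin() + (p - &v[0]);
}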
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9772c06bce..96c80fd577 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return count;
}
+/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type) {
+ if (shape.element_type() == primitive_type) {
+ return true;
+ }
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ if (HasPrimitiveType(element_shape, primitive_type)) {
+ return true;
+ }
+ }
+ return false;
+}
+
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
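A hedged usage sketch for the new helper; scanning a possibly nested tuple shape for a given element type is an assumed motivation, not stated in this change:

// Returns true: BF16 occurs inside the nested tuple element.
Shape shape = ShapeUtil::MakeTupleShape(
    {ShapeUtil::MakeShape(F32, {4}),
     ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(BF16, {2, 2})})});
bool has_bf16 = ShapeUtil::HasPrimitiveType(shape, BF16);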
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 8234fcdd3f..623ae39de8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -180,6 +180,10 @@ class ShapeUtil {
// As ElementsIn(), but recurses through tuples.
static int64 ElementsInRecursive(const Shape& shape);
+ // Returns true if the shape has the given primitive type; recurses through
+ // tuples.
+ static bool HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type);
+
// Returns true if 'shape' is an array with zero elements.
static bool IsZeroElementArray(const Shape& shape);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 6ca4085aaf..c622ecdca1 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
+TEST(ShapeUtilTest, HasPrimitiveType) {
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
+ S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
+ S16));
+}
+
TEST(ShapeUtilTest, IsZeroElementArray) {
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index d0bda45cf8..30e3077edb 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -647,6 +647,7 @@ xla_test(
],
shard_count = 48,
tags = [
+ "broken",
"manual",
"notap",
],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 0bf4556b43..c257566fb2 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -41,7 +41,6 @@ limitations under the License.
namespace xla {
namespace {
-
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001, 0.0001};
@@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Add(lhs_param, rhs_param);
@@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<int64> rhs{-1,
0,
@@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Sub(lhs_param, rhs_param);
@@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
XlaBuilder b(TestName());
std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
Lt(lhs_param, rhs_param);
- ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
+ ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
- auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
+ auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+ Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);
auto sum1 = Add(a_constant, b_constant);
@@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto sum = ConstantR0<float>(&b, 0.0f);
- auto param = Parameter(&b, 0, param_literal->shape(), "param");
+ auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
@@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
@@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
@@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
@@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
@@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
@@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
@@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
+ Literal literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
- client_->TransferToServer(*literal3).ConsumeValueOrDie();
+ client_->TransferToServer(literal3).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
- auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
+  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
@@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
Array3D<float> expected(0, 7, 0);
@@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
- auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
-0.79, 1.41, 1.21, 1.05});
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Tanh(input);
ComputeAndCompareR1<float>(
@@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Exp(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::exp(input_literal->Get<float>({i})));
+ expected_result.push_back(std::exp(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Log(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::log(input_literal->Get<float>({i})));
+ expected_result.push_back(std::log(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
- LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- auto a = ConstantLiteral(&builder, *a_literal);
+ Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ auto a = ConstantLiteral(&builder, a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
- auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
+ auto y = Parameter(&builder, 1, y_literal.shape(), "y");
auto slice = Slice(x, {1}, {2}, {1});
Sub(slice, y);
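// ---------------------------------------------------------------------------
// Editor's note: a minimal sketch, not part of the patch, showing the value-
// semantics pattern the elementwise tests above migrate to. The function and
// variable names are illustrative; only the API calls come from this diff.
// Before this change, literal factories returned std::unique_ptr<Literal>, so
// call sites wrote TransferToServer(*literal) and literal->shape(). After it,
// Literal is a move-only value type and the '*' and '->' disappear:
#include <memory>
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"

void ValueSemanticsSketch(xla::Client* client) {
  xla::XlaBuilder b("sketch");
  // CreateR1 now returns xla::Literal by value.
  xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.f, 2.f});
  // TransferToServer takes the literal directly (no dereference).
  std::unique_ptr<xla::GlobalData> data =
      client->TransferToServer(literal).ConsumeValueOrDie();
  // shape() is called with '.' rather than '->'.
  auto param = xla::Parameter(&b, 0, literal.shape(), "param");
  xla::Exp(param);
}
// ---------------------------------------------------------------------------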
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3adb6..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+ input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
- {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({0, 0}).get(),
- LiteralUtil::CreateR1<float>({16, 20}).get()});
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+ LiteralUtil::CreateR1<float>({0, 0}),
+ LiteralUtil::CreateR1<float>({16, 20})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
- auto expected = LiteralUtil::MakeTuple(
- {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
- LiteralUtil::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+ LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, *expected,
+ &builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal.shape(), "scale");
+ auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
- Parameter(&builder, 4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
- auto input_parameter =
- Parameter(&builder, 0, input_literal->shape(), "input");
- auto scale_parameter =
- Parameter(&builder, 1, scale_literal->shape(), "scale");
- auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
- auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+ auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+ auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
- Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
- client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+ client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
- auto expected =
- LiteralUtil::MakeTuple({expected_grad_activation.get(),
- LiteralUtil::CreateR1<float>(grad_scale).get(),
- LiteralUtil::CreateR1<float>(grad_offset).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+ LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
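// ---------------------------------------------------------------------------
// Editor's note: a sketch, not part of the patch, of the tuple-expectation
// change applied throughout the batch normalization tests above; it assumes
// the ClientLibraryTestBase environment and is illustrative only.
//
// Before: MakeTuple took raw Literal* elements, so each unique_ptr needed
// .get(), and the resulting tuple was itself a unique_ptr:
//   auto expected = LiteralUtil::MakeTuple(
//       {LiteralUtil::CreateR1<float>({4, 5}).get(),
//        LiteralUtil::CreateR1<float>({5, 5}).get()});
//   ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
//
// After: MakeTupleFromSlices accepts the element Literals directly (each
// converts to a non-owning LiteralSlice) and returns the tuple by value:
//   auto expected = LiteralUtil::MakeTupleFromSlices(
//       {LiteralUtil::CreateR1<float>({4, 5}),
//        LiteralUtil::CreateR1<float>({5, 5})});
//   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
// ---------------------------------------------------------------------------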
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 65589b0d6a..e9728e636f 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
- {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
- .get(),
+ {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
- .get()});
+ {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
- {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
- .get(),
+ {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
- .get()});
+ {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index fe4267c73b..dde19fb65d 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
- client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+ client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
- client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+ client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
- {r3_implicit_global_data.get(), r3_global_data.get()},
+ &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+ ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
+ &builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
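// ---------------------------------------------------------------------------
// Editor's note: a sketch, not part of the patch, of the two value-returning
// calls the broadcast tests above rely on; it assumes an Array3D<float>
// r3_array and a client_ as in the fixture, and the names are illustrative.
//
// Relayout now returns a Literal by value, so the chained call drops its '->':
//   auto r3_data = LiteralUtil::CreateR3FromArray3D(r3_array)
//                      .Relayout(LayoutUtil::MakeLayout({0, 1, 2}));
//   auto r3_global_data =
//       client_->TransferToServer(r3_data).ConsumeValueOrDie();
//
// Likewise, ConstantLiteral takes the literal directly (as a LiteralSlice)
// instead of a dereferenced unique_ptr:
//   auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
// ---------------------------------------------------------------------------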
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 74d4d2eb10..9966e4606e 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralSlice(*result, {0}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralSlice(*result, {1}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace
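// ---------------------------------------------------------------------------
// Editor's note: a sketch, not part of the patch, of the comparison side of
// this change. It assumes ExecuteAndTransfer now yields a Literal value (see
// the client_library_test_base.cc hunks below); names are illustrative.
//
//   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
//   EXPECT_TRUE(LiteralTestUtil::Near(
//       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
//       error_spec_));
//
// Tuple elements are compared through non-owning views rather than pointers:
//   EXPECT_TRUE(LiteralTestUtil::Near(
//       expected_element, LiteralSlice(result, {0}), error_spec_));
// ---------------------------------------------------------------------------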
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index b1d18210ea..8b31e53707 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant =
- ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+ auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
- auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
- auto tuple = LiteralUtil::MakeTuple({elem.get()});
- Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
+ auto tuple = LiteralUtil::MakeTuple({&elem});
+ Call(&builder, callee, {ConstantLiteral(&builder, elem)});
- ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+ ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace
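// ---------------------------------------------------------------------------
// Editor's note: a sketch, not part of the patch, of the variant kept by the
// CallR0F32Tuple test above. MakeTuple still exists, but it now takes pointers
// to plain Literal values instead of unique_ptr.get() results:
//   Literal elem = LiteralUtil::CreateR0<float>(42.0);
//   auto tuple = LiteralUtil::MakeTuple({&elem});
//   ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
// MakeTupleFromSlices, used elsewhere in this patch, is the convenient form
// when the elements are temporaries.
// ---------------------------------------------------------------------------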
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index a4eb57fc7b..2f1510ff69 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
- auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
Add(p0, p1);
auto param0_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto param1_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
@@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto computation = computation_status.ConsumeValueOrDie();
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
- auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
- client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
- auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+ auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
// Match
auto status = client_->Execute(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 8a236db0ff..fbdf0fcb65 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
- return result.ValueOrDie()->ToString();
+ return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
- verify_output(*actual, "");
+ verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual,
+ verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
- if (ShapeUtil::IsTuple(literal->shape())) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal->shape()));
+ ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
- std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+ std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
- literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+ literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+ ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
- client_->TransferToServer(*literal_relayout));
+ client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
- verify_output(*actual, error_message);
+ verify_output(actual, error_message);
return Status::OK();
};
@@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
@@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
@@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+ Literal expected_literal = LiteralUtil::CreateR1U8(expected);
- VLOG(1) << "expected: " << expected_literal->ToString();
- VLOG(1) << "actual: " << actual->ToString();
+ VLOG(1) << "expected: " << expected_literal.ToString();
+ VLOG(1) << "actual: " << actual.ToString();
- EXPECT_EQ(expected, actual->GetR1U8AsString());
+ EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
- ? *LiteralUtil::ConvertF32ToBF16(literal)
- : literal);
+ ? LiteralUtil::ConvertF32ToBF16(literal)
+ : LiteralSlice(literal));
}
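The LiteralSlice(literal) wrapper on the else branch is what makes this conditional compile: one branch is a prvalue Literal, the other a const Literal&, and the non-owning view is the common type both convert to. The view into the converted temporary stays valid because ConstantLiteral consumes it within the same full expression. The same trick in standard C++, with std::string_view standing in for LiteralSlice:

#include <iostream>
#include <string>
#include <string_view>

void Consume(std::string_view v) { std::cout << v << '\n'; }

std::string Convert(const std::string& s) { return s + "_bf16"; }

void Use(bool convert, const std::string& s) {
  // std::string converts implicitly to std::string_view but not vice
  // versa, so the conditional's common type is the view. The temporary
  // from Convert() lives to the end of the full expression, long enough
  // for Consume() to read through the view.
  Consume(convert ? Convert(s) : std::string_view(s));
}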
std::unique_ptr<GlobalData>
@@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
- return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}
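MaybeConvertLiteralToBfloat16 shows the other half of the value-semantics story: a function-local result is returned for free (moved or elided), while returning data the function does not own now takes an explicit, visible Clone(). A minimal sketch, assuming a Clone()-style deep copy on a move-only type:

// Move-only value type with an explicit deep copy, as Literal now is.
struct Value {
  Value() = default;
  Value(Value&&) = default;
  Value& operator=(Value&&) = default;
  Value(const Value&) = delete;
  Value Clone() const { return Value(); }
};

Value Transform(const Value&) { return Value(); }

Value MaybeTransform(bool transform, const Value& v) {
  if (transform) {
    return Transform(v);  // prvalue result: moved or elided, never copied
  }
  return v.Clone();  // caller keeps v, so the deep copy is spelled out
}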
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 22dfdfb0e4..9d32f4f517 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
 // This executes the computation via the reference client (which connects to
 // an interpreter backend). The result is used as the expected value of the
 // computation.
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
+ StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
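These signature changes, StatusOr<std::unique_ptr<Literal>> to StatusOr<Literal>, drop one heap allocation and one level of indirection per result, and call sites go from literal->ToString() to literal.ToString(). A sketch using absl::StatusOr as a stand-in for the xla::StatusOr of this era (an assumption; the contemporary type spelled the consuming accessor ConsumeValueOrDie()):

#include <string>
#include "absl/status/statusor.h"

struct Payload {
  std::string ToString() const { return "payload"; }
};

absl::StatusOr<Payload> MakePayload() { return Payload{}; }

std::string Demo() {
  absl::StatusOr<Payload> result = MakePayload();
  if (!result.ok()) return "error";
  return result->ToString();  // the value lives inside the StatusOr itself
}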
@@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
+ return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- absl::Span<const Literal> arguments);
+ StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
+ XlaBuilder* builder, absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@@ -412,9 +411,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -438,9 +435,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, absl::Span<const NativeT> expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -464,9 +459,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -490,9 +485,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -516,9 +511,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@@ -542,13 +537,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR0(value);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -556,13 +551,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR1(values);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -570,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -584,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
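For orientation, a hypothetical test body using the updated helper; the computation (add 1 to each element) and all values are invented for illustration, but the helper and comparison calls follow the signatures in this header:

XlaBuilder builder("add_one");
XlaOp param;
std::unique_ptr<GlobalData> data = CreateR1Parameter<float>(
    {1.0f, 2.0f, 3.0f}, /*parameter_number=*/0, "values", &builder, &param);
Add(param, ConstantR0<float>(&builder, 1.0f));  // scalar broadcasts
ComputeAndCompareR1<float>(&builder, {2.0f, 3.0f, 4.0f}, {data.get()},
                           ErrorSpec(1e-5));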
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index c898dacf48..6f2ca84bb6 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
- auto computed, client_->Transfer(*data, &expected_literal->shape()));
+ auto computed, client_->Transfer(*data, &expected_literal.shape()));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralSlice(*result, {0}));
+ LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralSlice(*result, {1}));
+ LiteralSlice(result, {1}));
- EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
- EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+ EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
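The LiteralSlice(result, {i}) expressions above are the new way to inspect tuple elements: a shallow, non-owning view rooted at element i of a value Literal, where the old code dereferenced a pointer. In isolation (assuming a two-element tuple-shaped result, as in this test):

Literal result =
    client_->ExecuteAndTransfer(computation, {}).ConsumeValueOrDie();
LiteralSlice element0(result, {0});  // view only; no data is copied
LiteralSlice element1(result, {1});
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));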
@@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
+ LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
- client_->Transfer(*results[0], &expected_result->shape()));
+ client_->Transfer(*results[0], &expected_result.shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 03d5696499..6ef7ca035f 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
absl::Span<GlobalData* const> arguments,
float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
- std::unique_ptr<Literal> result =
+ Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data_handle).ConsumeValueOrDie();
+ Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
- client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
- client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 8226b6de3f..3b0414a604 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
- StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const XlaOp& operand, XlaBuilder* builder,
- Layout* output_layout = nullptr) {
+ StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
- return literal->Get<Scalar>({});
+ return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<int32>({4, 6});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index be017477d8..9811a015e9 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto y_bcast = Broadcast(y, {1, 5, 7});
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 25d10ab00a..32cac499c7 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
- LiteralUtil::CreateR0<float>(25.0f).get()}),
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+ LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
- LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<bool>(true).get(),
- LiteralUtil::CreateR0<float>(12.2f).get(),
- LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(true),
+ LiteralUtil::CreateR0<float>(12.2f),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(46.6f).get(),
- LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
- .get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
- LiteralUtil::CreateR0<float>(9.3f).get()})
- .get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(46.6f),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+ LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
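MakeTupleFromSlices is the replacement for the old MakeTuple over a vector of pointers: it takes a span of LiteralSlice views and deep-copies each into a freshly built tuple Literal, so passing prvalue literals, as these tests do, is safe. The idiom in isolation:

// Each prvalue Literal converts to a LiteralSlice view; the views are
// copied into the tuple before the temporaries die at the end of the
// full expression.
Literal tuple = LiteralUtil::MakeTupleFromSlices(
    {LiteralUtil::CreateR0<float>(12.0f),
     LiteralUtil::CreateR1<float>({13.0f, 16.0f})});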
@@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
- LiteralUtil::CreateR0<float>(b).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4937574831..72ff1e74a4 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4D(input_array);
+ Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *input_literal);
+ ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
@@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})}));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
+ Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
- LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ LiteralSlice(result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+ ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 7a203d6873..5f063e6784 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, U32);
@@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 38b6da4fa9..fd98bf29b8 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index d2c6478b02..070b092d18 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
- auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
+ auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r5).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r5,
+ ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
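The convolution tests all build operands the same way: fill a rank-1 literal with sequential values, then reshape it, where Reshape now returns StatusOr<Literal> by value. The idiom in isolation (std::iota is from <numeric>):

std::vector<float> elems(2 * 3);
std::iota(elems.begin(), elems.end(), 1.0f);  // 1, 2, ..., 6
Literal r1 = LiteralUtil::CreateR1<float>(elems);
// Reshape reinterprets the dimensions over the same data; it can fail on
// an element-count mismatch, hence the StatusOr.
Literal r2 = r1.Reshape({2, 3}).ConsumeValueOrDie();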
@@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(param0)),
- std::move(*LiteralUtil::CreateFromArray(param1))},
+ {LiteralUtil::CreateFromArray(param0),
+ LiteralUtil::CreateFromArray(param1)},
error_spec_);
}
@@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
- auto expected_r3 =
- expected_r1->Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
+ auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+ client_->TransferToServer(input_r3).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r3,
+ client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
@@ -891,9 +890,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillIota(10);
- ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))});
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)});
+}
+
+XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
+ XlaBuilder builder(TestName());
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
+ Array4D<float> input_data(1, 64, 100, 100);
+ input_data.FillRandom(/*value=*/0.023, /*mean=*/0.001, /*seed=*/45321);
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64});
+ Array4D<float> filter_data(7, 7, 1, 64);
+ filter_data.FillRandom(/*value=*/0.023, /*mean=*/0.001, /*seed=*/45320);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = ConstantR4FromArray4D(&builder, filter_data);
+
+ // Specify bf01_01io->bf01 as dimension numbers.
+ ConvolutionDimensionNumbers dnums;
+ // Input
+ dnums.set_input_feature_dimension(1);
+ dnums.set_input_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ // Kernel
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ // Output
+ dnums.set_output_batch_dimension(0);
+ dnums.set_output_feature_dimension(1);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(3);
+ ConvGeneral(input, filter, /*window_strides=*/{1, 1},
+ /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
+ /*feature_group_count=*/64);
+
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
+ error_spec_);
}
class ConvolutionHloTest : public HloTestBase {};
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 6784c16715..ba3e9c436e 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
- weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto weights = ConstantLiteral(&builder, *weights_literal);
+ weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto weights = ConstantLiteral(&builder, weights_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
- activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
- auto activations = ConstantLiteral(&builder, *activations_literal);
+ activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
+ auto activations = ConstantLiteral(&builder, activations_literal);
auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto forward_conv =
ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 526626c1dd..1407e68d9a 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ auto constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
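HloInstruction::CreateConstant now takes its Literal by value, so the call site chooses between keeping and handing off the data; TestCopyOp clones because its caller still owns the literal. A sketch of the two options (variable names are illustrative):

Literal literal = LiteralUtil::CreateR0<float>(42.0f);
// Keep using `literal` afterwards: pass an explicit deep copy.
auto keep = HloInstruction::CreateConstant(literal.Clone());
// Done with it: hand the storage over without copying.
auto give = HloInstruction::CreateConstant(std::move(literal));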
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+ TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4(
+ TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
- Shape shape = literal->shape();
+ Shape shape = literal.shape();
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(module), {literal.get()});
- LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {&literal});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
- Layout* literal_layout =
- literal->mutable_shape_do_not_use()->mutable_layout();
+ Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
@@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}
@@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+ LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}
void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+ Literal literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+ LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
- auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
+ auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}
} // namespace
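Where a test previously took an owning copy via CloneToUnique(), it now calls
Clone(), and HloInstruction::CreateConstant accepts the resulting Literal by
move; ExecuteAndTransfer likewise hands back a Literal value that the
comparison helpers take by const reference. A sketch of the copy-test idiom,
assuming the HloTestBase fixture above and a const Literal& named `literal`:

    auto builder = HloComputation::Builder("copy_sketch");
    auto constant = builder.AddInstruction(
        HloInstruction::CreateConstant(literal.Clone()));  // Literal, by move
    builder.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kCopy, constant));
    auto module = CreateNewModule();
    module->AddEntryComputation(builder.Build());
    Literal result = ExecuteAndTransfer(std::move(module), {});
    EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));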
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index d12a4e7fcd..410732c07b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
- EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+ EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(
- *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}
// On the GPU backend, constants get special handling. Someone might pass a
@@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0}));
}
} // namespace
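Because Literal now supports value equality, EXPECT_EQ compares the literals
themselves, and LiteralUtil::MakeTuple takes raw pointers to stack-allocated
elements rather than .get() on owning pointers. A short sketch (element values
are illustrative):

    auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
    auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
    // MakeTuple copies from the pointed-to literals; the locals stay owners.
    Literal tuple = LiteralUtil::MakeTuple({&literal0, &literal1});
    EXPECT_EQ(tuple, tuple.Clone());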
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 6f7fc0e6e5..a693fa3595 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
module->AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
- Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};
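The LiteralTestUtil expectations (ExpectR0Near, ExpectR3EqualArray3D, Near,
Equal) now take the result by const Literal&, which is why the ubiquitous
*result dereference disappears throughout these hunks. The post-change
assertion shape, assuming the fixture's module and error_spec_:

    Literal result = ExecuteAndTransfer(std::move(module), {});
    LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);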
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index eb15fc0593..e0f23b0fa8 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
// Try copying the elements back and comparing them
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
handles1[0].reset();
handles1[1].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
// Try deallocating one of the repeated elements, then copy
handles[0].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
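client_->Transfer now yields StatusOr<Literal>, so TF_ASSERT_OK_AND_ASSIGN
lands the result in a stack Literal, and TransferToServer takes const
Literal&, leaving ownership with the caller. A sketch inside a
ClientLibraryTestBase body (values are illustrative):

    Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
    std::unique_ptr<GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
    Literal literal;
    TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*param0_data));
    LiteralTestUtil::ExpectR1Equal<float>({3.14f, -100.25f}, literal);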
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 5873516442..0171f51583 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
+ LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
- param.dot_lhs_row_major)));
+ Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
- client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
- std::unique_ptr<Literal> dot_rhs_lit =
+ Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
- client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> addend_data;
- std::unique_ptr<Literal> addend_lit;
+ Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;
if (param.has_addend) {
@@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
- addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
+ addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}
XlaBuilder builder(TestName());
@@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
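Two further variants of the migration appear in the dot tests above:
LiteralUtil::MakeTupleFromSlices assembles a tuple directly from element
literals with no .get() plumbing, and TransferToServer consumes the factory's
value without a dereference. Sketch (element values are illustrative):

    Literal tuple = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
         LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})});
    auto handle = client_->TransferToServer(tuple).ConsumeValueOrDie();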
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 9bf3767ca3..7501c6d957 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special in that it cannot be a Span<bool>, which
// is what the code below wants, so instead we do this.
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie();
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*LiteralUtil::CreateR0(input_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(input_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_value =
- std::move(*LiteralUtil::CreateR0(update_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(update_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_value =
- std::move(*LiteralUtil::CreateR0(expected_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(expected_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR1(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR3FromArray3D<NativeT>(values);
- LOG(INFO) << name << ":" << literal->ToString();
+ Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << literal.ToString();
}
};
@@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- auto input = ConstantLiteral(&builder, *input_literal);
+ auto input = ConstantLiteral(&builder, input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- stream.get(), *start_indices_literal, buffer));
+ stream.get(), start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
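The slice helpers convert integer test data to the parameterized type DataT;
Convert is now a member of the value Literal and returns StatusOr<Literal>, so
the old std::move(*...) wrapper reduces to a plain chain (the std::move
retained in several hunks above is redundant but harmless). Sketch with DataT
fixed to float:

    std::vector<int> input_values_int = {1, 2, 3, 4};
    Literal input_values =
        LiteralUtil::CreateR1<int>(input_values_int)
            .Convert(primitive_util::NativeToPrimitiveType<float>())
            .ValueOrDie();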
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 5116e60ca6..b08ece0e63 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index bf1de02ba9..51b50d456e 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
i < known_incorrect_range.second) {
// If the operation is known to be buggy on a specific input, clamp that
// input to 0 under the assumption that the op is at least correct on 0.
- input_literal->Set({i - begin}, 0.0f);
+ input_literal.Set({i - begin}, 0.0f);
} else {
- input_literal->Set({i - begin}, tensorflow::bit_cast<float, int>(i));
+ input_literal.Set({i - begin}, tensorflow::bit_cast<float, int>(i));
}
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
enqueue_op(&builder, input);
std::vector<float> expected_result;
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(evaluate_op(input_literal->Get<float>({i})));
+ expected_result.push_back(evaluate_op(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
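Set and Get make the same -> to . move: the literal is built, mutated in
place, transferred with TransferToServer, and read back for the expected
values. A sketch filling a small F32 literal (the size 16 is illustrative):

    Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {16});
    for (int64 i = 0; i < 16; ++i) {
      input_literal.Set({i}, static_cast<float>(i));
    }
    float third = input_literal.Get<float>({2});  // element at index 2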
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7cb2f0cedf..9c94acb437 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -117,9 +117,9 @@ class FusionTest : public HloTestBase {
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
}
@@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
@@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
// Every element of result should be y = x^2 = 4.0.
for (int i = 0; i < rand_dim0_size; ++i) {
for (int j = 0; j < dim1_size; ++j) {
- EXPECT_EQ(4.0, result->Get<float>({i, j}));
+ EXPECT_EQ(4.0, result.Get<float>({i, j}));
}
}
}
@@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
@@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
@@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
@@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
@@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
@@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
@@ -782,19 +782,17 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> operand =
- LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ Literal operand = LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text, config));
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
- test_runner_.Execute(std::move(module), {operand.get()},
- /*run_hlo_passes=*/false));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
+ test_runner_.Execute(std::move(module), {&operand},
+ /*run_hlo_passes=*/false));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
- *result));
+ LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ result));
}
class FusionClientLibraryTest : public ClientLibraryTestBase {};
@@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
// where overflow is OK.
Array2D<uint32> arr(32, 32);
arr.FillUnique();
- std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({0, 1}));
- std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({1, 0}));
- XlaOp p0 = AddParam(*l1, &b);
+ XlaOp p0 = AddParam(l1, &b);
XlaOp sum = p0;
for (int i = 1; i < kNumParams; ++i) {
- auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
sum = sum + p0 * pN * pN;
}
@@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) {
auto param0_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
- client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
- client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param2_literal, device_ordinal)
.ConsumeValueOrDie();
// Build executable.
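Relayout chains off the factory value in the same way, and
LiteralToShapedBuffer takes the literal by const reference when staging the
benchmark inputs. Sketch, assuming the LocalClient* client and device_ordinal
from the benchmark above:

    Array2D<uint32> arr(32, 32);
    arr.FillUnique();
    Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
        LayoutUtil::MakeLayout({0, 1}));
    ScopedShapedBuffer buffer =
        client->LiteralToShapedBuffer(l1, device_ordinal).ConsumeValueOrDie();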
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 6d63498044..daa89398a6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -58,10 +58,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -79,10 +79,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -100,11 +100,10 @@ ENTRY main {
slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -122,11 +121,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -144,11 +143,11 @@ ENTRY main {
slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -166,13 +165,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -190,13 +188,12 @@ ENTRY main {
slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -214,10 +211,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -235,11 +232,10 @@ ENTRY main {
slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -257,9 +253,9 @@ ENTRY main {
slice_sizes={1, 0}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -281,11 +277,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -307,11 +303,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
+ Literal start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -333,11 +329,11 @@ ENTRY main {
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -359,11 +355,11 @@ ENTRY main {
ROOT result = u32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -381,10 +377,10 @@ ENTRY main {
slice_sizes={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -402,9 +398,9 @@ ENTRY main {
slice_sizes={1}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -422,10 +418,10 @@ ENTRY main {
slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -446,10 +442,10 @@ ENTRY main {
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -470,11 +466,10 @@ ENTRY main {
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -495,11 +490,11 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -520,13 +515,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest,
@@ -548,13 +542,12 @@ ENTRY main {
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -575,10 +568,10 @@ ENTRY main {
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -599,11 +592,10 @@ ENTRY main {
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), start_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> operand_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::vector<std::unique_ptr<xla::GlobalData>> result_data,
client_->ExecuteParallel(computation_instances));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
- *result_literal);
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
}
} // namespace
} // namespace xla
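Editor's note: the hunks above in gather_operation_test.cc are all instances of one mechanical migration: the LiteralUtil::CreateR* factories now return a Literal by value instead of a std::unique_ptr<Literal>, so call sites drop the .get() calls and pass the address of the stack value instead. A minimal before/after sketch (RunTest stands in for any test helper taking Literal* arguments, as in the tests above):

    // Before: heap-allocated literal, raw pointer obtained via .get().
    // std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2});
    // RunTest(hlo_text, operand.get(), start_indices.get());

    // After: value semantics; the Literal lives on the stack and only its
    // address is lent out for the duration of the call.
    Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
    Literal start_indices = LiteralUtil::CreateR0<int32>(1);
    RunTest(hlo_text, &operand, &start_indices);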
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 3df99aac7d..bdd4fd7e3d 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -136,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() {
return debug_options;
}
-StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments);
}
-std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_
.Execute(std::move(module), arguments,
/*run_hlo_passes=*/false)
.ValueOrDie();
}
-std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}
@@ -188,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
TF_ASSIGN_OR_RETURN(auto reference,
reference_runner_.Execute(std::move(reference_module),
arguments, run_hlo_passes));
- return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
+ return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
error);
}
@@ -223,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompare(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -243,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -277,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return test_runner_
.Execute(std::move(module_or_status.ValueOrDie()),
fake_argument_ptrs, /*run_hlo_passes=*/true)
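Editor's note: the lambda rewrites above are the one spot where the migration is not a pure simplification: MakeFakeArguments now yields std::vector<Literal>, but the runner still takes absl::Span<Literal* const>, and absl::c_transform hands its unary op a const reference, hence the const_cast. A cast-free sketch of the same pointer-vector construction, assuming (as the tests do) that nothing downstream mutates the arguments:

    std::vector<Literal> fake_arguments =
        MakeFakeArguments(module.get()).ConsumeValueOrDie();
    std::vector<Literal*> fake_argument_ptrs;
    fake_argument_ptrs.reserve(fake_arguments.size());
    for (Literal& literal : fake_arguments) {
      fake_argument_ptrs.push_back(&literal);  // non-const iteration, no cast
    }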
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 21d77c0cc4..0ae4bdc104 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -115,16 +115,16 @@ class HloTestBase : public ::testing::Test {
}
// Executes the given module and returns the result as a Literal.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Same as above, except the module will be executed without running any HLO
// passes on it.
- std::unique_ptr<Literal> ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
- std::unique_ptr<Literal> ExecuteAndTransfer(
- std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+ Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Executes the given hlo module on two backends and compares results.
//
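Editor's note: with Execute returning StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>, the literal is moved straight out of the status wrapper. An illustrative usage sketch (module construction and the expected literal elided):

    TF_ASSERT_OK_AND_ASSIGN(Literal result,
                            Execute(std::move(module), {&arg_literal}));
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));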
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 96f72212f3..43cca91f64 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -155,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
absl::Span<const NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -176,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
absl::Span<const NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -223,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -232,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
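Editor's note: every dereference removed in these helpers relies on the same fact: Equal and Near take LiteralSlice parameters, and a Literal now converts to a LiteralSlice directly, so the freshly built expected literal binds to the parameter with no pointer in between. A short sketch, assuming that implicit conversion (which is what the updated templates depend on):

    Literal expected =
        LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
    // `expected` converts to LiteralSlice at the call site; `actual` is any
    // literal or slice produced by the test.
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));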
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 4151bfae03..b6f9b8156b 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,11 +31,11 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal literal = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal lhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(64).get(),
- LiteralUtil::CreateR0<int32>(42).get(),
+ Literal rhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(64),
+ LiteralUtil::CreateR0<int32>(42),
});
- CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
};
ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
}
@@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto two = LiteralUtil::CreateR0<float>(2);
auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
- CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
};
tensorflow::Env* env = tensorflow::Env::Default();
@@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
- EXPECT_EQ("2", literal->ToString());
+ EXPECT_EQ("2", literal.ToString());
} else if (result.find("actual") != string::npos) {
- EXPECT_EQ("4", literal->ToString());
+ EXPECT_EQ("4", literal.ToString());
} else if (result.find("mismatches") != string::npos) {
- EXPECT_EQ("true", literal->ToString());
+ EXPECT_EQ("true", literal.ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
@@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
- ::testing::AssertionResult result =
- LiteralTestUtil::Equal(*expected, *actual);
+ ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
EXPECT_THAT(result.message(),
::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
EXPECT_THAT(result.message(),
@@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
@@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
@@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b =
LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
- EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
- EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
} // namespace
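Editor's note: these tuple tests show the replacement for pointer-based tuple assembly: LiteralUtil::MakeTuple took raw Literal* elements (hence the .get() noise), while MakeTupleFromSlices takes the element literals themselves, viewed as slices, and returns the assembled tuple by value. A minimal sketch mirroring the test above:

    Literal tuple = LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<int32>(42),
        LiteralUtil::CreateR0<int32>(64),
    });
    EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));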
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 237a4a361e..dbdd20daf0 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
DefaultExecutableBuildOptions(), options);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
@@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
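Editor's note: both directions of the host/device round trip lose a level of indirection here: LiteralToShapedBuffer accepts the literal by const reference, so a temporary from a CreateR* factory binds directly, and ShapedBufferToLiteral returns the literal by value, so it can feed an expectation inline. A sketch of the round trip as these allocation tests use it (execution into `result` elided):

    auto x_array =
        LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
    // ... execute the computation to obtain a ScopedShapedBuffer `result` ...
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);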
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 1a823cf189..a99b43f469 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
auto y = ConstantR0<float>(&builder, 123.0f);
Add(x, y);
- auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
- LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
auto y = ConstantR1<float>(&builder, {});
Add(x, y);
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
- LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
@@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions().set_execution_profile(&profile));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
@@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
ScopedShapedBuffer result_colmaj =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with the parameter values in a different order.
ScopedShapedBuffer result_param_swap =
ExecuteLocallyOrDie(computation, {&y_array, &x_array});
- LiteralTestUtil::ExpectR2Near<float>(
- {{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_param_swap), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ ShapedBufferToLiteral(result_param_swap),
+ error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
@@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with row-major result layout.
@@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_rowmaj),
+ ShapedBufferToLiteral(result_rowmaj),
error_spec_);
}
@@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {2}));
+ LiteralSlice(result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
+ LiteralSlice(result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
+ LiteralSlice(result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralSlice(result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
options, DefaultExecutableRunOptions());
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
+ auto y_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});
- auto x_buffer = LiteralToShapedBuffer(*x_literal);
- auto y_buffer = LiteralToShapedBuffer(*y_literal);
+ auto x_buffer = LiteralToShapedBuffer(x_literal);
+ auto y_buffer = LiteralToShapedBuffer(y_literal);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
@@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
+ Literal result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
+ LiteralSlice(result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
- LiteralSlice(*result_0_literal, {1}));
+ LiteralSlice(result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
- std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
+ Literal result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
- LiteralSlice(*result_1_literal, {0}));
+ LiteralSlice(result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
- LiteralSlice(*result_1_literal, {1}));
+ LiteralSlice(result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
// Feed in a tuple where each two-element vector element is {tuple_index,
// -tuple_index}.
- std::vector<std::unique_ptr<Literal>> arg_elements;
+ std::vector<Literal> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
- std::unique_ptr<Literal> arg_literal =
- LiteralUtil::MakeTupleOwned(std::move(arg_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
}
}
@@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::vector<std::unique_ptr<Literal>> outer_tuple_elements;
+ std::vector<Literal> outer_tuple_elements;
for (int i = 0; i < kFanout; ++i) {
- std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
+ std::vector<Literal> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
@@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
}
auto arg_literal =
LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
- LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
- error_spec_);
+ LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
+ LiteralSlice(result_literal, {i, j}),
+ error_spec_);
}
}
}
@@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
+ Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
- std::vector<std::unique_ptr<Literal>> arg_vector;
+ std::vector<Literal> arg_vector;
arg_vector.push_back(std::move(arg_literal));
arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
ShapeIndex index;
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(165.0,
- LiteralSlice(*result_literal, index));
+ LiteralSlice(result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
DefaultExecutableRunOptions().set_device_ordinal(d));
EXPECT_EQ(d, result.device_ordinal());
LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ ShapedBufferToLiteral(result));
}
}
}
@@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
// As a check to verify that the computation ran on the device associated
// with the stream. This is a weak check, but stronger verification is hard.
EXPECT_EQ(d, result.device_ordinal());
- LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
}
}
@@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
+ Literal tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
- LiteralSlice(*tuple_literal, {0}));
+ LiteralSlice(tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
- LiteralSlice(*tuple_literal, {1}));
+ LiteralSlice(tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
@@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
// Array shapes.
- test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
- test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
- test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*LiteralUtil::MakeTuple({}));
+ test_to_device_and_back(LiteralUtil::MakeTuple({}));
// Non-nested tuples.
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(12223.0)}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}));
// Nested tuple.
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()})
- .get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
test_to_device_and_back(
- *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+ LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}),
+ LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
Add(in, constant);
- std::unique_ptr<Literal> result;
+ Literal result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
}));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
// Join the thread.
thread.reset();
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
@@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
local_client_->TransferFromOutfeedLocal(
shape, local_client_->default_device_ordinal()));
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
// Benchmark that measures the overhead of the LocalClient API when running a
@@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) {
auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
- buffer));
+ ASSERT_IS_OK(
+ transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
const int kWarmups = 2;
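Editor's note: the LargeTuple, LargeNestedTuple, and DeepTuple hunks above use the ownership-transfer variant: elements accumulate in a std::vector<Literal> and MakeTupleOwned consumes the vector by move, so each element is moved into the tuple rather than copied. A sketch of the pattern, including the wrap-one-more-level step DeepTuple performs in a loop:

    std::vector<Literal> elements;
    elements.push_back(LiteralUtil::CreateR1<float>({1.0f, -1.0f}));
    elements.push_back(LiteralUtil::CreateR1<float>({2.0f, -2.0f}));
    Literal tuple = LiteralUtil::MakeTupleOwned(std::move(elements));

    // Nesting: move the existing tuple into a fresh single-element vector.
    std::vector<Literal> wrapper;
    wrapper.push_back(std::move(tuple));
    tuple = LiteralUtil::MakeTupleOwned(std::move(wrapper));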
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index a8c68fc7fd..f90ef22d2d 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
.ConsumeValueOrDie();
}
-std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+Literal LocalClientTestBase::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
return local_client_->ShapedBufferToLiteral(shaped_buffer)
.ConsumeValueOrDie();
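Editor's note: returning Literal by value from this helper is cheap: the StatusOr<Literal> produced by the client call owns the literal, and ConsumeValueOrDie moves it out, so the return involves a move (or is elided), never a deep copy of the buffer contents. A hedged sketch of the same shape:

    xla::StatusOr<Literal> transferred =
        local_client_->ShapedBufferToLiteral(shaped_buffer);
    Literal literal = transferred.ConsumeValueOrDie();  // moves the Literal out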
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 90095c5d41..4027c7b124 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test {
// Construct and return a literal containing the array represented by
// shaped_buffer.
- std::unique_ptr<Literal> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Execute the given computation on the local client. With and without
// options.
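Editor's note: the map_test.cc hunks that follow repeat the client-side half of the change: TransferToServer now takes the literal itself rather than a dereferenced pointer, and shape access switches from -> to value syntax. A sketch of the updated parameter setup, mirroring the hunks below:

    Literal param0_literal = LiteralUtil::CreateR1<float>({2.2f, 3.3f});
    std::unique_ptr<GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
    auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");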
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e195d4..4d327a6fe9 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies (lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 array.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
@@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
{0});
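
Every hunk in this file applies the same mechanical rewrite: the LiteralUtil factories now return Literal by value instead of std::unique_ptr<Literal>, so call sites drop the '*' when handing the literal along and switch '->' to '.' for member access. A minimal sketch of the new call shape, assuming the surrounding test fixture's client_ (AddR1Param is an illustrative name, not part of this change):

#include <memory>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void AddR1Param(Client* client, XlaBuilder* builder) {
  // Value type now; previously std::unique_ptr<Literal>.
  Literal literal = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  // TransferToServer takes const Literal&, so no '*' dereference.
  std::unique_ptr<GlobalData> data =
      client->TransferToServer(literal).ConsumeValueOrDie();
  // Member access switches from literal->shape() to literal.shape().
  Parameter(builder, 0, literal.shape(), "param0");
}

}  // namespace xla
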
@@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1});
@@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1, 2});
@@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param2_literal =
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+ Literal param2_literal =
LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
- client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param2_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
- auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
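
The named-variable style above (bind the Literal, then transfer) also reflects that Literal is now a move-only value type: copies must be requested explicitly, which keeps ownership as clear as the old unique_ptr made it. A short sketch of the ownership rules, assuming Literal::Clone() as declared in literal.h:

#include <utility>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void OwnershipSketch() {
  Literal a = LiteralUtil::CreateR0<float>(10.0f);
  Literal b = std::move(a);  // ownership transfers; 'a' is left empty
  Literal c = b.Clone();     // deep copies are explicit, never implicit
  (void)c;
}

}  // namespace xla
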
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index edb592f43e..3f278115e0 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
});
Exp(data);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
{0.36788f, 1.64872f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
@@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
});
Map(&builder, {data}, add_half, {0, 1});
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
{-0.5f, 1.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
@@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
});
Max(lhs, rhs);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
{3.0f, -4.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6));
}
struct TestLinspaceMaxParam {
@@ -200,14 +200,12 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
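
MatOpsDotAddTest shows a second benefit of the value API: when a literal is used exactly once, the factory result can be passed straight into TransferToServer with no named variable, since the temporary binds to the const Literal& parameter for the duration of the call. A reduced sketch (TransferTemporary is an illustrative name):

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void TransferTemporary(Client* client) {
  // The temporary Literal lives until TransferToServer returns, so binding
  // it to the const Literal& parameter is safe.
  auto handle_or = client->TransferToServer(
      LiteralUtil::CreateR2WithLayout<float>({{1, 2}, {3, 4}},
                                             LayoutUtil::MakeLayout({1, 0})));
  (void)handle_or;
}

}  // namespace xla
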
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index c5e0b9b097..56aaeb0e68 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -114,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
+ Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
auto actual =
- ExecuteAndTransfer(std::move(hlo_module),
- {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
@@ -178,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect =
- std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
};
@@ -218,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
LiteralUtil::CreateR0<float>(1.0)),
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
LiteralUtil::CreateR0<int32>(4)));
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -247,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -280,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}
const char* const kScalarOps = R"(
@@ -324,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -356,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -389,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
- LiteralUtil::CreateR1<float>({36, 64}),
- LiteralUtil::CreateR1<float>({66, 138})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -422,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -456,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -492,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -530,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest,
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
auto init2 = LiteralUtil::CreateR0<float>(6);
- std::unique_ptr<Literal> result = ExecuteNoHloPasses(
- std::move(module), {param.get(), init1.get(), init2.get()});
+ Literal result =
+ ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -565,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest,
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
LiteralUtil::CreateR3<Eigen::half>(
@@ -576,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest,
{Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)},
{Eigen::half(7), Eigen::half(8)}}})),
- *result));
+ result));
}
} // namespace
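
Two related signature changes run through this file: ExecuteNoHloPasses and ExecuteAndTransfer now return Literal by value (so arguments are passed as {&param} rather than {param.get()}), and LiteralTestUtil::Equal/Near compare const Literal& directly. A sketch of the comparison half (CheckTupleResult is an illustrative helper, not part of the change):

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

void CheckTupleResult(const Literal& result) {
  // MakeTupleOwned moves its operands into the tuple, so the expected value
  // is built inline; Equal takes const Literal& on both sides.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}

}  // namespace xla
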
diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
index 0a0426adcb..f2460822a6 100644
--- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
+++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
@@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
GetTupleElement(result_tuple, 0);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
VLOG(1) << "Transferring trip count to computation";
// Transfer number of iterations to Infeed.
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
// Pick up value from outfeed
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 1);
+ EXPECT_EQ(r.Get<int32>({}), 1);
}
VLOG(1) << "Writing data to infeed";
// Transfer some data to Infeed for use inside the loop.
TF_ASSERT_OK(local_client_->TransferToInfeed(
- *LiteralUtil::CreateR1<int32_t>({10, 20})));
+ LiteralUtil::CreateR1<int32_t>({10, 20})));
// Pick up value from outfeed
{
VLOG(1) << "Reading from body outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&xfeed_shape));
- EXPECT_EQ(r->Get<int32>({0}), 11);
- EXPECT_EQ(r->Get<int32>({1}), 21);
+ EXPECT_EQ(r.Get<int32>({0}), 11);
+ EXPECT_EQ(r.Get<int32>({1}), 21);
}
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 0);
+ EXPECT_EQ(r.Get<int32>({}), 0);
}
// Joins the thread
thread.reset();
- EXPECT_EQ(comp_result->Get<int32>({}), 0);
+ EXPECT_EQ(comp_result.Get<int32>({}), 0);
}
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
@@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
}));
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&result_shape));
- EXPECT_EQ(r->Get<bool>({}), true);
+ EXPECT_EQ(r.Get<bool>({}), true);
// Join the thread
thread.reset();
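
The infeed/outfeed client calls follow the same convention: TransferToInfeed accepts const Literal&, and TransferFromOutfeed yields a Literal, which TF_ASSERT_OK_AND_ASSIGN can bind directly. Outside a test body the equivalent is TF_ASSIGN_OR_RETURN; a sketch (ReadOutfeed is an illustrative name):

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

Status ReadOutfeed(LocalClient* client) {
  Shape shape = ShapeUtil::MakeShape(S32, {2});
  // StatusOr<Literal> now, not StatusOr<std::unique_ptr<Literal>>.
  TF_ASSIGN_OR_RETURN(Literal r, client->TransferFromOutfeed(&shape));
  VLOG(1) << "first outfeed element: " << r.Get<int32>({0});
  return Status::OK();
}

}  // namespace xla
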
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index cbeddffacf..6e98167739 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
@@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
@@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
@@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
Pad(AddParam(input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
@@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
- padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(3.14f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) {
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
- padding_config);
+ Pad(reduce, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
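
pad_test also exercises Literal::Relayout, which returns a fresh Literal by value; the updated code simply move-assigns it back over the original instead of reseating a unique_ptr. A two-line sketch:

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

void RelayoutSketch() {
  Literal input = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  // Relayout builds a new literal with the requested layout; move-assigning
  // replaces the old contents in place.
  input = input.Relayout(LayoutUtil::MakeLayout({0, 1}));
}

}  // namespace xla
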
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index f6c762e7a4..dcb4c11c3c 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR0<float>(3.14159f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
@@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
@@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
@@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ Literal param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0,
ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
@@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
@@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
@@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
// Use both parameters
//
@@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ Literal literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
@@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ Parameter(&builder, 1, literal1.shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 =
- LiteralUtil::CreateR1<float>({10, 20, 30});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, literal1.shape(), "param2");
// This add is unused.
Add(param1, param2);
@@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ Literal literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
- client_->TransferToServer(*literal).ConsumeValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ client_->TransferToServer(literal).ConsumeValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
+ Literal literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
sum_handle = Add(sum_handle, param);
}
@@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
- parameter_shapes.push_back(literal->shape());
+ parameter_shapes.push_back(literal.shape());
}
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
+ Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
+ std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
XlaOp bool_param =
- Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
+ Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
params.push_back(bool_param);
- parameter_shapes.push_back(bool_literal->shape());
+ parameter_shapes.push_back(bool_literal.shape());
auto init = Tuple(&builder, params);
@@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest,
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
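
Note the elements.reserve(kParamCount) added alongside the std::vector<Literal> switch: ptrs stores addresses of vector elements, and without the reservation a push_back could reallocate the buffer and invalidate every pointer taken so far (the old vector of unique_ptrs kept each literal at a stable heap address, so it never had this hazard). A sketch of the safe pattern (MakeTupleOfR1s is an illustrative name):

#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

Literal MakeTupleOfR1s(int count) {
  std::vector<Literal> elements;
  std::vector<const Literal*> ptrs;
  // Reserve first: reallocation would invalidate the addresses in 'ptrs'.
  elements.reserve(count);
  for (int i = 0; i < count; ++i) {
    elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
    ptrs.push_back(&elements.back());
  }
  return LiteralUtil::MakeTuple(ptrs);
}

}  // namespace xla
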
@@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR1<float>({4, 5, 6}),
}))
.ConsumeValueOrDie();
@@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ Literal literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});
- const Shape original = literal->shape();
+ const Shape original = literal.shape();
{
// Reverse the layout present in original, and make that the layout of the
// literal.
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
original.layout().minor_to_major().begin(),
original.layout().minor_to_major().end());
std::reverse(original_layout.begin(), original_layout.end());
- *literal->mutable_shape_do_not_use()->mutable_layout() =
+ *literal.mutable_shape_do_not_use()->mutable_layout() =
LayoutUtil::MakeLayout(original_layout);
- ASSERT_EQ(2, literal->Get<float>({0, 1}));
+ ASSERT_EQ(2, literal.Get<float>({0, 1}));
}
// Use the original shape in building the computation.
XlaBuilder builder(TestName());
@@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
// Check that we got the off-diagonal value that we expected.
Array2D<float> expected(1, 1);
expected(0, 0) = 2;
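
Where the old code assembled a tuple from .get() pointers, temporaries now go through LiteralUtil::MakeTupleFromSlices, which views each element as a LiteralSlice and copies it into a new owning tuple. A reduced sketch of the TupleOfR1ParametersAddedTogether setup:

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

Literal MakeExampleTuple() {
  // Each element is a temporary Literal; the slices are copied into the
  // returned tuple, so no named variables are needed.
  return LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR1<float>({1, 2, 3}),
      LiteralUtil::CreateR1<float>({4, 5, 6}),
  });
}

}  // namespace xla
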
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 5f322b768d..8f2c26f0ee 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -37,8 +37,7 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- std::unique_ptr<Literal> UniformTest(T a, T b, absl::Span<const int64> dims,
- int64 seed = 42);
+ Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
// of the given range size. `expected_count` is the number of times each
@@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase {
};
template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
- absl::Span<const int64> dims,
- int64 seed) {
+Literal PrngTest::UniformTest(T a, T b, absl::Span<const int64> dims,
+ int64 seed) {
XlaBuilder builder(TestName());
RngUniform(
ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -60,8 +58,8 @@ std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
SetSeed(seed);
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
- EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
- actual->EachCell<T>([=](absl::Span<const int64>, T value) {
+ EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions()));
+ actual.EachCell<T>([=](absl::Span<const int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
@@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
constexpr int64 count = 100;
for (int64 seed = 0; seed < count; ++seed) {
auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
- result->Literal::EachCell<bfloat16>(
- [&](absl::Span<const int64>, bfloat16 value) {
- int64 index = static_cast<int64>((value - low) / interval);
- counts[index]++;
- });
+ result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
+ int64 index = static_cast<int64>((value - low) / interval);
+ counts[index]++;
+ });
}
// Each bucket should have a similar number of counts, i.e. not more than
// 10% of the total. This mostly tests that we don't fall into a 1:2:2
@@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
- actual->EachCell<int32>(
+ actual.EachCell<int32>(
[&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
int64 sum = 0;
for (int32 i = 0; i < range_size; ++i) {
@@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
};
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
- client_->TransferToServer(*param0_literal));
+ client_->TransferToServer(param0_literal));
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto fn = build_sum_rng(builder);
Map(&builder, {param0}, fn, {0});
@@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
computation,
/*arguments=*/{param0_data.get()}, &execution_options));
- EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
- ShapeUtil::ElementsIn(param0_literal->shape()));
- for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
- EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
- EXPECT_LT(actual->data<float>()[i],
- param0_literal->data<float>()[i] + 1.0f);
+ EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()),
+ ShapeUtil::ElementsIn(param0_literal.shape()));
+ for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) {
+ EXPECT_GE(actual.data<float>()[i], param0_literal.data<float>()[i]);
+ EXPECT_LT(actual.data<float>()[i], param0_literal.data<float>()[i] + 1.0f);
}
}
@@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
ExecutionOptions execution_options2 = execution_options_;
execution_options2.set_seed(65);
- std::unique_ptr<Literal> result1;
+ Literal result1;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
&execution_options1));
}
- std::unique_ptr<Literal> result2;
- std::unique_ptr<Literal> result3;
+ Literal result2;
+ Literal result3;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options1));
}
- std::unique_ptr<Literal> result4;
- std::unique_ptr<Literal> result5;
- std::unique_ptr<Literal> result6;
+ Literal result4;
+ Literal result5;
+ Literal result6;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_));
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
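
PrngTest::UniformTest's new signature is the function-return side of the migration: helpers return Literal by value (move semantics make this cheap) and callers iterate the result directly with EachCell. A self-contained sketch of that shape (MakeSample and CheckSample are illustrative names):

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

Literal MakeSample() {
  return LiteralUtil::CreateR1<float>({0.1f, 0.5f, 0.9f});  // moved out
}

void CheckSample() {
  Literal actual = MakeSample();
  actual.EachCell<float>([](absl::Span<const int64>, float value) {
    CHECK_GE(value, 0.0f);
    CHECK_LT(value, 1.0f);
  });
}

}  // namespace xla
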
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index 9af9ea4a22..c9096fb29b 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ Literal reduce_input = LiteralUtil::CreateR4<float>(
{{ /*i0=0*/
{/*i1=0*/
{-0.246092796, -0.179497838, -0.161181688},
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 0916a07f4f..26e2bfde5c 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR1<float>({input_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
ReducePrecision(a, exponent_bits, mantissa_bits);
@@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// Abs doesn't affect resolution.
auto abs = Abs(a);
@@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
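Note: every hunk in reduce_precision_test.cc above is the same mechanical rewrite: `LiteralUtil::CreateR1` now returns a `Literal` by value, `TransferToServer` takes a `const Literal&` instead of a dereferenced pointer, and `shape()` is reached with `.` rather than `->`. A minimal before/after sketch of that calling convention, assuming an `XlaBuilder builder` and a connected `Client* client_` as in the tests (the surrounding test scaffolding is elided):

  // Before: heap-allocated literal, dereferenced at each use.
  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001f});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
  auto a = Parameter(&builder, 0, a_literal->shape(), "a");

  // After: Literal is a plain value type; no indirection needed.
  Literal b_literal = LiteralUtil::CreateR1<float>({1.00001f});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  auto b = Parameter(&builder, 1, b_literal.shape(), "b");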
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 57f7fed61f..83997cdac2 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase {
}, 4);
// clang-format on
CHECK(ShapeUtil::Equal(
- literal_3d_->shape(),
+ literal_3d_.shape(),
ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
- << literal_3d_->shape().ShortDebugString();
+ << literal_3d_.shape().ShortDebugString();
}
// Runs an R1 => R0 reduction test with the given number of elements.
@@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase {
input_data[i] *= -1;
}
}
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(AsSlice(input_data));
+ Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (float item : input_data) {
@@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase {
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+ Literal input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
bool expected = and_reduce;
for (bool item : input_data) {
@@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::array<bool, cols> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
@@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
// NativeT can be bool, and std::vector<bool> does not convert to
// Span.
@@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase {
reference_reduction_function_for_uints, unsigned_int_identity);
}
- std::unique_ptr<Literal> literal_2d_;
- std::unique_ptr<Literal> literal_3d_;
+ Literal literal_2d_;
+ Literal literal_3d_;
uint32 seed_ = 0xdeadbeef;
};
@@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
- MakeFakeLiteral(input_shape));
+ TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
- ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4));
+ ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
}
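Note: `MakeFakeLiteral` now returns the `Literal` inside its `StatusOr` by value, so `TF_ASSERT_OK_AND_ASSIGN` binds a `Literal` directly and the value is moved, not copied, into `ComputeAndCompare`. Sketch under that assumption; the `input_shape` here is hypothetical scaffolding inferred from the test name, not copied from the elided test body:

  Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
  TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
  // The literal is moved into the argument list; no dereference, no copy.
  ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));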
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
@@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3D(input_data);
+ Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 major = 0; major < 2; ++major) {
@@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
input.Each(
@@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
auto input_min = FLT_MAX;
@@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
@@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XlaBuilder builder("reduce_among_y");
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
@@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
@@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
@@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
+ input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
GetParam().reduce_dims);
@@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
auto a = ConstantR0<float>(&builder, 2.0f);
auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal =
- LiteralUtil::CreateR1<float>({1.0f, 4.0f});
+ Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b = Parameter(&builder, 0, b_literal.shape(), "b");
Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
@@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest {
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
- max_fn, {0});
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
+ {0});
ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
}
@@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1<float>(operand);
+ Literal input_literal = LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
- client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
&builder, expected, {input_global_data.get(), input_global_data2.get()},
ErrorSpec(0.0001));
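Note: besides the creation/transfer rewrites, reduce_test.cc also migrates two other patterns: the fixture members become plain `Literal literal_2d_; Literal literal_3d_;` fields, and `Relayout` returns a fresh `Literal` that is move-assigned back into the same variable. A sketch of the relayout idiom, with the bounds and layout indices taken as representative values from the hunks above:

  const int64 rows = 111, cols = 50;
  const int64 minor = 0, major = 1;
  Array2D<float> input_data(rows, cols);
  input_data.FillRandom(3.14f, 0.04);
  Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
  // Relayout produces a new Literal; move-assignment replaces the old storage.
  input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();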
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index a1001296a1..63491a90bf 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
&builder_);
ReduceWindow(input, init,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
ReduceWindow(input, init_value,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
/*window_dimensions=*/{},
/*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
{}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
{}, ErrorSpec(0.00001));
}
@@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
Min(Add(lhs, rhs),
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
ReduceWindow(
input,
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
{}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
@@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -360,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- auto arg_literal = absl::make_unique<Literal>(shape);
- arg_literal->PopulateWithValue(1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ Literal arg_literal(shape);
+ arg_literal.PopulateWithValue(1.0f);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
@@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
Shape result_shape =
ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
- auto expected = absl::make_unique<Literal>(result_shape);
- expected->PopulateWithValue(27.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ Literal expected(result_shape);
+ expected.PopulateWithValue(27.0f);
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 1;
int stride = 8;
@@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 3;
int stride = 1;
@@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 8;
int stride = 5;
@@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
@@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -594,7 +588,7 @@ string R4ReduceWindowTestDataToString(
// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -614,11 +608,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], param.base_bounds[3]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(4);
@@ -627,7 +620,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
auto reducer = param.reducer;
if (use_bfloat16() && Product(param.window_bounds) > 128) {
@@ -659,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateFromArray(*expected);
+ Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
- input_literal->shape().element_type(),
- AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
- ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+ input_literal.shape().element_type(),
+ AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
+ ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
DefaultErrorSpec(), &expected_shape_with_layout);
}
};
@@ -988,7 +980,7 @@ string R3ReduceWindowTestDataToString(
param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1008,12 +1000,11 @@ TEST_P(R3ReduceWindowTest, DoIt) {
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2]);
input.FillRandom(0.1f, 0.1f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
auto reducer = param.reducer;
if (use_bfloat16()) {
- input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal);
+ input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
if (Product(param.window_bounds) > 128) {
// To avoid numerical issues, force the reducer to be kMax for large bf16
// windows.
@@ -1021,9 +1012,9 @@ TEST_P(R3ReduceWindowTest, DoIt) {
}
}
- XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input");
+ XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
@@ -1035,7 +1026,7 @@ TEST_P(R3ReduceWindowTest, DoIt) {
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec());
+ ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(
@@ -1130,7 +1121,7 @@ string R2ReduceWindowTestDataToString(
param.layout[1], //
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1147,12 +1138,11 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(2);
for (int i = 0; i < 2; ++i) {
@@ -1162,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1178,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
@@ -1332,7 +1322,7 @@ string R1ReduceWindowTestDataToString(
"__pad_high_", absl::StrJoin(param.pad_high, "x"),
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1352,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) {
const float kInitValue = 0.0f;
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ auto input_arg =
+ CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
std::vector<std::pair<int64, int64>> padding(1);
padding[0] = {param.pad_low[0], param.pad_high[0]};
@@ -1365,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1384,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
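Note: two recurring patterns in reduce_window_test.cc are worth calling out. First, a `Literal` is now constructed directly from a `Shape` and filled in place, replacing `absl::make_unique<Literal>`. Second, the test-name helpers switch from `str = absl::StrCat(str, ...)` to `absl::StrAppend`, which appends into the existing string rather than rebuilding it. A sketch, reusing the R6 shape from the hunks above:

  std::vector<int64> input_dims(6, 8);
  Shape shape = ShapeUtil::MakeShape(F32, input_dims);
  Literal arg_literal(shape);           // value construction, no heap handle
  arg_literal.PopulateWithValue(1.0f);  // fill every element in place

  string str = "R4ReduceWindowTest";
  absl::StrAppend(&str, "_bfloat16");   // append in place instead of a StrCat copy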
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index d891451381..5cf87e565b 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect 4.
- LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(4, literal);
}
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
@@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed,
/*arguments=*/{x_data.get(), y_data.get()},
@@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
.ConsumeValueOrDie();
// Expect 5.
- LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(5, literal);
}
TEST_F(ReplayTest, MapPlusTwoOverR1) {
@@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect result.
- LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, literal);
}
} // namespace
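Note: replay_test.cc shows the client-API side of the migration: `ExecuteAndTransfer(...).ConsumeValueOrDie()` now yields a `Literal` value, and the `LiteralTestUtil` expectations take a `const Literal&` with no dereference. Sketch of the result-handling pattern, assuming `replayed` and `execution_options_` as in the test:

  Literal literal =
      client_
          ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
          .ConsumeValueOrDie();
  LiteralTestUtil::ExpectR0Equal<int32>(4, literal);  // no '*' dereference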
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 17d12715f6..dedc95b5ae 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
auto a = Neg(parameter);
Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) {
Array2D<float> input_array(0, 3);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) {
Array2D<float> input_array(3, 0);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
auto input_literal =
LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 3});
auto expected_literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 3, 0, 0});
auto expected_literal =
LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{24, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
auto expected_literal = LiteralUtil::CreateFromArray(expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{8, 3});
@@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
{35, 36, 37},
{40, 41, 42},
{45, 46, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{8, 3});
@@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
{45, 16, 26},
{36, 46, 17},
{27, 37, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{2, 6, 2});
auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
t2x2x2x3.FillWithYX(*filler2x3);
auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 1, 1) = 7;
auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{2, 4});
auto expected_literal =
LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
Reshape(parameter, dimensions, {});
auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
- ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&b, expected_literal, {input.get()},
zero_error_spec_);
}
}
@@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {}, {});
EXPECT_THAT(
@@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {1}, {});
EXPECT_THAT(ExecuteToString(&b, {}),
@@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
LayoutUtil::MakeLayout({0, 1, 2, 3}));
// clang-format on
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
@@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0});
- std::unique_ptr<Literal> actual =
+ Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralUtil::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(expected);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
{{204, 205, 206, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
{{206, 7, 107, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
Array4D<float> input(2, 1, 1, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
Array4D<float> input(2, 1, 4, 1);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
Array4D<float> input(5, 10, 2, 3);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
/*new_sizes=*/{5, 60});
@@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
*cell;
});
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
input_array.Each(
[&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
/*new_sizes=*/{7, 2, 3, 5});
XlaComputation computation = builder.Build().ConsumeValueOrDie();
@@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
{2, 3, 0, 1});
- std::unique_ptr<Literal> output_literal =
+ Literal output_literal =
client_
->ExecuteAndTransfer(computation, {input_data.get()},
&execution_options)
@@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
- EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
+ auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
+ EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
} else {
- EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
+ EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
}
}
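
The no-op check just above works by pinning the output layout through ExecutionOptions so the runtime cannot normalize it away. A condensed sketch of that mechanism, assuming the fixture's execution_options_ default used elsewhere in these tests:

    // Request the result in an explicit layout; since the reshape is
    // layout-preserving, the transferred bytes must equal the input's.
    ExecutionOptions execution_options = execution_options_;
    *execution_options.mutable_shape_with_output_layout() =
        ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1});
    Literal output_literal =
        client_
            ->ExecuteAndTransfer(computation, {input_data.get()},
                                 &execution_options)
            .ConsumeValueOrDie();
    EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
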
@@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
+ ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
@@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
/*new_sizes=*/{2, 4, 3, 1});
@@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
{{16}, {20}, {24}}}});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()});
+ ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
@@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
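
Value returns also change how calls compose: ReshapeSlice previously handed back a unique_ptr, so Relayout was reached through ->; it is now invoked directly on the returned temporary. A one-expression sketch of the chain used in this and the following transpose tests:

    // ReshapeSlice yields a Literal rvalue; Relayout consumes it and
    // produces the relaid-out result without any heap indirection.
    Literal expected =
        LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
            .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
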
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
@@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
@@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
@@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
@@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
- ->Relayout(input_literal->shape().layout());
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal)
+ .Relayout(input_literal.shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 74ded82ddf..4e55b0d7ac 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) {
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
- auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
+ auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto a = AddParam(*input_literal, &builder);
+ auto a = AddParam(input_literal, &builder);
Rev(a, spec.reversal);
- std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
+ Literal expected = input_literal.Clone();
std::vector<int64> output_indices(spec.input_dims.size());
- expected->EachCell<float>([&](absl::Span<const int64> indices, float) {
+ expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
for (int64 i = 0; i < indices.size(); ++i) {
output_indices[i] = indices[i];
}
- float value = input_literal->Get<float>(indices);
+ float value = input_literal.Get<float>(indices);
for (int64 dim : spec.reversal) {
output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
}
- expected->Set<float>(output_indices, value);
+ expected.Set<float>(output_indices, value);
});
- ComputeAndCompareLiteral(&builder, *expected, {});
+ ComputeAndCompareLiteral(&builder, expected, {});
}
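
CloneToUnique() is gone in favor of Clone(), which returns a Literal value that can be mutated in place. A reduced sketch of how the reverse test builds its expected literal from a deep copy:

    // Deep-copy the input, then rewrite each element at its reversed
    // index; Clone() is now the explicit way to copy a move-only Literal.
    Literal expected = input_literal.Clone();
    expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
      std::vector<int64> output_indices(indices.begin(), indices.end());
      for (int64 dim : spec.reversal) {
        output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
      }
      expected.Set<float>(output_indices, input_literal.Get<float>(indices));
    });
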
INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index e692b8c5d5..091a5d2cac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -38,7 +38,7 @@ namespace {
class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
protected:
// Sends the literal to the server and retrieves it back.
- std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ Literal RoundTripToServer(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
return client_->Transfer(*data).ConsumeValueOrDie();
@@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
+ Literal actual =
reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0, actual->Get<float>({0}));
- EXPECT_EQ(24.0, actual->Get<float>({1}));
+ EXPECT_EQ(42.0, actual.Get<float>({0}));
+ EXPECT_EQ(24.0, actual.Get<float>({1}));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
@@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
} // namespace
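
PackedLiteralReader::Read now yields StatusOr<Literal>, so the value is extracted once and borrowed everywhere afterwards. A condensed sketch of the read-then-round-trip flow shared by both layout variants above:

    // Read a 2x2 f32 literal in the requested layout, then push it
    // through the server and back, passing by const reference.
    Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
                         .ConsumeValueOrDie();
    EXPECT_TRUE(reader.IsExhausted());
    Literal round_tripped = RoundTripToServer(actual);
    EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
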
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index a8193c2eac..cd5a531603 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
void RoundTripTest(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data).ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
+ Literal result = client_->Transfer(*data).ConsumeValueOrDie();
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, result));
}
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+ RoundTripTest(LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+ RoundTripTest(LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(
- *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ RoundTripTest(LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*LiteralUtil::MakeTuple({}));
+ RoundTripTest(LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
- LiteralUtil::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
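
The tuple tests also migrate from LiteralUtil::MakeTuple, which took raw Literal* (hence the old .get() calls), to MakeTupleFromSlices, which assembles the tuple from the element literals themselves — copying their data, if it matches the slice-based signature these call sites suggest:

    // Build a (f32[2], f32[2]) tuple literal directly from temporaries;
    // no unique_ptr or .get() plumbing remains.
    Literal tuple = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR1<float>({1, 2}),
         LiteralUtil::CreateR1<float>({3, 4})});
    RoundTripTest(tuple);
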
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 07460a7e01..1dd937a6d0 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
+ Literal a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
@@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+ Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> c_data =
- client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+ client_->TransferToServer(c_literal).ConsumeValueOrDie();
- XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
- XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
- XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+ XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
+ XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
+ XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
Mul(Mul(a, b), c);
ComputeAndCompareR0<float>(&builder, 5.775f,
@@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(div_computation,
@@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend / divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(rem_computation,
@@ -430,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend % divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
- TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
+ Literal literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
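
Client transfers now borrow the literal as const Literal&, while GlobalData — a handle to server-resident storage — stays behind unique_ptr. A minimal parameter-setup sketch following the hunks above:

    // Upload a scalar, keep the server-side handle, and declare a
    // matching parameter; shape() is read straight off the value.
    Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
    std::unique_ptr<GlobalData> a_data =
        client_->TransferToServer(a_literal).ConsumeValueOrDie();
    XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
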
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 1858dcea61..d20dba028a 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -62,13 +62,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
@@ -92,13 +90,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
@@ -123,13 +120,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
@@ -154,13 +149,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
@@ -185,13 +178,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
@@ -216,13 +208,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
@@ -247,13 +237,12 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
@@ -277,15 +266,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
@@ -309,15 +296,13 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
@@ -341,12 +326,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
@@ -370,13 +354,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
@@ -400,11 +382,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
@@ -429,12 +410,11 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
@@ -458,13 +438,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
@@ -488,13 +468,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NegativeIndex) {
@@ -518,13 +498,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OneScalarIndex) {
@@ -548,12 +528,12 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates =
LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ScalarUpdate) {
@@ -577,10 +557,10 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, EmptyIndices) {
@@ -604,10 +584,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
+ Literal updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
} // namespace
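
RunTest keeps its Literal* parameters, so every scatter call site now takes the address of a stack-allocated literal instead of calling .get() on an owning pointer. The resulting shape of each test body, sketched with illustrative values:

    // The literals live on the stack for the duration of the test;
    // RunTest borrows them through raw, non-owning pointers.
    Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}});
    Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 1});
    Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {40, 50, 60}});
    RunTest(hlo_text, &operand, &scatter_indices, &updates);
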
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index c9a58aefb4..a40c2d7de6 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
XlaBuilder builder(TestName());
auto original = ConstantR4FromArray4D(&builder, values);
Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
- &expected_literal->shape());
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
+ &expected_literal.shape());
}
struct R1Spec {
@@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase,
auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = Parameter(&builder, 0, literal->shape(), "p0");
+ auto original = Parameter(&builder, 0, literal.shape(), "p0");
Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
@@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase,
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
}
};
@@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) {
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = Parameter(&builder, 0, literal->shape(), "p0");
+ auto a = Parameter(&builder, 0, literal.shape(), "p0");
Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
@@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase,
XlaBuilder builder(TestName());
auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
- auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
+ auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 3ae31191a0..5155f0c652 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
// array. This uniqueness is best-effort only. Some types (half and bfloat16)
// are not supported and uniqueness cannot be guaranteed if the number of
// elements exceeds the number of different values supported by the type.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> element,
+ Literal element,
MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
@@ -131,60 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = absl::make_unique<Literal>(shape);
+ Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<half>(&literal, engine,
no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<float>(&literal, engine,
no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ PopulateWithRandomFloatingPointData<double>(&literal, engine,
no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
- no_duplicates);
+ PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(
- literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
+ literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
- std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +286,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+StatusOr<Literal> CreateLiteralForConstrainedUses(
const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
@@ -358,9 +351,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
- return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type());
case ConstantType::kOne:
- return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type());
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
@@ -374,34 +367,33 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param,
- std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+ const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random) {
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random) {
auto engine =
pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
return MakeFakeArguments(module, engine.get());
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
arguments[i] =
MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
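For readers tracking the tuple branch above: with value semantics, the recursive case builds each element literal, moves it into a std::vector<Literal>, and assembles the tuple by moving the vector. A sketch under the same API, using a hypothetical MakeFakeTuple wrapper around the internal helper shown above:

  StatusOr<Literal> MakeFakeTuple(const Shape& shape, std::minstd_rand0* engine) {
    std::vector<Literal> elements;
    for (const Shape& element_shape : shape.tuple_shapes()) {
      TF_ASSIGN_OR_RETURN(Literal element,
                          MakeFakeLiteralInternal(element_shape, engine,
                                                  /*no_duplicates=*/false));
      elements.push_back(std::move(element));  // moved, not copied, into the vector
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }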
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index a260271b1b..b3c8a73905 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -57,8 +57,8 @@ class PseudorandomGenerator {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. See below for documentation of pseudo_random.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random = true);
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
+ bool pseudo_random = true);
// Generates a vector of arguments containing fake data. The number, shape and
// layout of the arguments are appropriate for the given HLO module.
@@ -84,14 +84,14 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
// TODO(b/79942829): Make interesting argument generation fast enough that using
// pseudo_random does not save any noticeable amount of time so that the
// parameter can be removed.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random = true);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random = true);
// Overload which accepts a random number generator. This enables generation of
// different random values with sequential calls to MakeFakeArguments by reusing
// the same generator.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, std::minstd_rand0* engine);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine);
// Check that a given module satisfies various constraints before trying to
// execute it.
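A sketch of the updated calling convention for the two overloads declared above, assuming an HloModule parsed elsewhere; arguments now arrive as std::vector<Literal>, so elements are used directly rather than dereferenced:

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                          MakeFakeArguments(module.get()));
  const Literal& first = args[0];  // was *args[0]

  // Reusing one generator across calls yields different values each time.
  std::minstd_rand0 engine;
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> more_args,
                          MakeFakeArguments(module.get(), &engine));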
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 322c8ef090..181e5cbe29 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
@@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
}
)")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 2);
- const Literal& key_arg = *args[0];
+ const Literal& key_arg = args[0];
tensorflow::gtl::FlatSet<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index c7eb9e2dbe..b34fd0f2e8 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
@@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(true);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(42, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(42, result.Get<int32>({}));
}
{
@@ -204,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(false);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(7, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(7, result.Get<int32>({}));
}
}
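Execute still takes its arguments as a span of Literal pointers, so a stack-allocated literal is now passed by address instead of via unique_ptr::get(). A sketch of the migrated pattern from the conditional tests above:

  auto arg = LiteralUtil::CreateR0<bool>(true);   // Literal value, not unique_ptr
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
  EXPECT_EQ(42, result.Get<int32>({}));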
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 125513ddfd..d6641d257a 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase {
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR0<uint32>(42);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
+ LiteralTestUtil::ExpectR0Equal<uint32>(42, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
- *result);
+ result);
}
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1<float>(test_vector);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
+ LiteralTestUtil::ExpectR1Equal<float>(test_vector, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1U8(test_string);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_EQ(result->GetR1U8AsString(), test_string);
+ EXPECT_EQ(result.GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different from the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
- LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
+ LayoutUtil::Equal(result.shape().layout(), literal.shape().layout()));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTuple({});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
+ Literal literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<complex64>(
- {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
- .get(),
- LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
@@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
// supported.
auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
- std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(456.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
-
- auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
- auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+ Literal literal1 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ Literal literal2 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(456.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());
auto stream1 = stream_;
auto stream2 = stream_->GetOrCreateSubStream();
- std::unique_ptr<Literal> result1, result2;
+ Literal result1, result2;
// Round trip literals through device in multiple streams asynchronously.
for (int i = 0; i < kIterationCount; ++i) {
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
device_buffer1));
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
device_buffer2));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result1,
+ Literal this_result1,
transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result2,
+ Literal this_result2,
transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
result1 = std::move(this_result1);
result2 = std::move(this_result2);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
class TransferDeviceToHostBenchmark : public TransferManagerTest {
@@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
}
tensorflow::testing::StopTiming();
@@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
}
tensorflow::testing::StopTiming();
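All of the transfer-manager tests above reduce to one round-trip idiom under the new API; a condensed sketch, assuming the fixture's transfer_manager_, stream_, and AllocateDeviceBuffer members:

  Literal literal = LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f});
  auto device_buffer = AllocateDeviceBuffer(literal.shape());
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                          device_buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));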
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index f2b3b49015..619d2a388b 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
@@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
- LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar1),
+ LiteralUtil::CreateR0<float>(constant_scalar2)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
- LiteralUtil::CreateR1<float>({}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
@@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
auto expected = LiteralUtil::MakeTuple({});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>(constant_matrix).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>(constant_matrix),
+ LiteralUtil::CreateR1<float>(constant_vector)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
- LiteralUtil::CreateR0<bool>(!direction).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(direction),
+ LiteralUtil::CreateR0<bool>(!direction)});
- ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+ ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
@@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
- LiteralUtil::CreateR1<float>(vec2).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
@@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) {
auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ LiteralUtil::MakeTuple({&expected_v1, &expected_s});
auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
- auto expected =
- LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::MakeTuple(
- {
- LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
- })
- .get(),
- LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+ }),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
}))
.ConsumeValueOrDie();
@@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
- .get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<complex64>({1, 2}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
- {{10000, 20000}, {30000, 40000}}})
- .get()})
- .get()}))
+ {{10000, 20000}, {30000, 40000}}})})}))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
->TransferToServer(
- *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
auto sum =
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = absl::make_unique<Literal>(sum->shape());
- ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](absl::Span<const int64> indexes) {
- return sum->Get<complex64>(indexes) *
- (indexes[indexes.size() - 1] == 0
- ? complex64(1, 2)
- : complex64(1, -2));
- })
+ Literal prod(sum.shape());
+ ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+ return sum.Get<complex64>(indexes) *
+ (indexes[indexes.size() - 1] == 0
+ ? complex64(1, 2)
+ : complex64(1, -2));
+ })
.ok());
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
- LiteralUtil::CreateR0<complex64>({123, 456}).get()});
- ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+ LiteralUtil::CreateR0<complex64>({123, 456})});
+ ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
.ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
- auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+ auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ result));
}
// Disabled on interpreter due to lack of outfeed.
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest,
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
TF_EXPECT_OK(Execute(std::move(module),
- {param0.get(), param1.get(), param1.get(),
- param0.get(), param4.get()})
+ {&param0, &param1, &param1, &param0, &param4})
.status());
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = Literal::CreateFromShape(expected->shape());
+ auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), *literal));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+ backend().default_stream_executor(), expected.shape(), literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
} // namespace
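The tuple tests above move between three constructors, and the distinction is worth spelling out: MakeTuple takes pointers to existing literals, MakeTupleFromSlices accepts element values (including temporaries) and copies their contents, and MakeTupleOwned moves elements in. A sketch, with signatures inferred from the call sites in this diff:

  Literal a = LiteralUtil::CreateR0<float>(1.0f);
  Literal b = LiteralUtil::CreateR1<float>({2.0f, 3.0f});

  // Pointer-based: built from the pointed-to literals.
  Literal by_pointer = LiteralUtil::MakeTuple({&a, &b});

  // Slice-based: temporaries are accepted directly and their data copied.
  Literal from_slices = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(1.0f),
       LiteralUtil::CreateR1<float>({2.0f, 3.0f})});

  // Owning: element literals are moved into the tuple.
  Literal owned = LiteralUtil::MakeTupleOwned(std::move(a), std::move(b));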
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 8f80a9f3e4..4fbd7f2fb1 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
{-inf<float>(), 0}});
Abs(arg);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper<complex64>() {
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
Sign(arg);
- std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
+ Literal expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper<complex64>() {
auto abs = Abs(arg);
Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
@@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) {
Add(sgnc, ConvertElementType(
Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, SignTestR1) {
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..7abd8651d5 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// have all reached 2.0.
auto expected_data =
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = LiteralUtil::MakeTuple({expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple(
+ {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
- auto expected = LiteralUtil::MakeTuple(
- {expected_counter.get(), expected_predicate.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+ auto expected =
+ LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR0<int32>(7);
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({&expected_element, &expected_element});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
auto add1 = LiteralUtil::CreateR0<int32>(15);
auto add2 = LiteralUtil::CreateR0<int32>(16);
- auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(
- auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
XlaBuilder builder(TestName());
While(condition, body, ConstantR0<int32>(&builder, 0));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
ComputeAndCompareR0<int32>(&builder, 2, {});
}
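TransferToInfeed follows the same convention, taking the literal by const reference, and expected tuples in the while tests are built from addresses of stack literals. A short sketch of both patterns, assuming the test fixture's client_:

  // Temporaries bind to the const Literal& parameter directly.
  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));

  // Expected tuples reference stack literals by address.
  auto expected_counter = LiteralUtil::CreateR0<int32>(5);
  auto expected_data = LiteralUtil::CreateR0<int32>(7);
  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});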
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7fd42944de..db5a824de0 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
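Literal::CreateFromShape now returns a zero-filled Literal by value, so the profile test can hand it straight to the transfer manager without dereferencing; a sketch, assuming the stream and argument buffer set up earlier in the test:

  Literal zeros = Literal::CreateFromShape(lhs_arg_shape);  // zero-filled placeholder
  TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(stream_ptr.get(), zeros,
                                                         lhs_arg));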
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 442e66321e..cdde88c135 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -39,8 +39,7 @@ limitations under the License.
namespace xla {
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
- absl::string_view path) {
+StatusOr<Literal> TextLiteralReader::ReadPath(absl::string_view path) {
CHECK(!absl::EndsWith(path, ".gz"))
<< "TextLiteralReader no longer supports reading .gz files";
std::unique_ptr<tensorflow::RandomAccessFile> file;
@@ -57,7 +56,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
: file_(file) {}
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+StatusOr<Literal> TextLiteralReader::ReadAllLines() {
tensorflow::io::RandomAccessInputStream stream(file_.get());
tensorflow::io::BufferedInputStream buf(&stream, 65536);
string shape_string;
@@ -74,9 +73,9 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
ShapeUtil::HumanString(shape));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
- result->PopulateWithValue<float>(fill);
+ result.PopulateWithValue<float>(fill);
std::vector<absl::string_view> pieces;
std::vector<absl::string_view> coordinates;
std::vector<int64> coordinate_values;
@@ -116,7 +115,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
"\"%s\"",
shape.dimensions_size(), coordinate_values.size(), line);
}
- result->Set<float>(coordinate_values, value);
+ result.Set<float>(coordinate_values, value);
}
return std::move(result);
}
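A sketch of consuming the reader's new interface: the Literal is built locally, populated in place, and moved into the StatusOr on return, so callers receive it by value (the path and indices here are illustrative):

  TF_ASSIGN_OR_RETURN(Literal literal,
                      TextLiteralReader::ReadPath("/tmp/literal.txt"));
  float value = literal.Get<float>({0, 0, 0});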
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index b265640802..c40b43279f 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -41,7 +41,7 @@ class TextLiteralReader {
public:
// See class comment -- reads a file in its entirety (there must be only one
// literal in the text file path provided).
- static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path);
+ static StatusOr<Literal> ReadPath(absl::string_view path);
private:
// Ownership of file is transferred.
@@ -49,7 +49,7 @@ class TextLiteralReader {
// Parses a shape string on the first line, followed by lines of values to the
// end of the file.
- StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+ StatusOr<Literal> ReadAllLines();
// Owns the file being read
std::unique_ptr<tensorflow::RandomAccessFile> file_;
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 92f9b4f9f0..1fab4e3a08 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) {
tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
.ok());
- std::unique_ptr<Literal> literal =
- TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+ Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
- EXPECT_EQ(42.5, literal->Get<float>({0, 0, 0}));
- EXPECT_EQ(43.5, literal->Get<float>({0, 0, 1}));
- EXPECT_EQ(44.5, literal->Get<float>({0, 0, 2}));
- EXPECT_EQ(45.5, literal->Get<float>({0, 1, 0}));
- EXPECT_EQ(46.5, literal->Get<float>({0, 1, 1}));
- EXPECT_EQ(47.5, literal->Get<float>({0, 1, 2}));
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape()));
+ EXPECT_EQ(42.5, literal.Get<float>({0, 0, 0}));
+ EXPECT_EQ(43.5, literal.Get<float>({0, 0, 1}));
+ EXPECT_EQ(44.5, literal.Get<float>({0, 0, 2}));
+ EXPECT_EQ(45.5, literal.Get<float>({0, 1, 0}));
+ EXPECT_EQ(46.5, literal.Get<float>({0, 1, 1}));
+ EXPECT_EQ(47.5, literal.Get<float>({0, 1, 2}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 4ea02faffc..5cbaf2fcc1 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) {
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
- ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+ ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path));
string contents;
TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
&contents));
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index ba814af476..0c41f227b3 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -121,11 +121,10 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
}
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
- Literal::CreateFromProto(proto));
+ TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer data,
- client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
scoped_shaped_buffer_arguments.push_back(std::move(data));
}
for (const auto& argument : scoped_shaped_buffer_arguments) {
@@ -161,12 +160,12 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// --generate_fake_infeed is passed and there exists an infeed operation in
// the HloSnapshot.
absl::optional<tensorflow::thread::ThreadPool> pool;
- std::unique_ptr<Literal> data;
+ Literal data;
if (provide_infeed) {
data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
}
auto transfer_infeed = [&data, client]() {
- TF_CHECK_OK(client->TransferToInfeed(*data));
+ TF_CHECK_OK(client->TransferToInfeed(data));
};
if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
@@ -214,9 +213,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< "s: " << module.hlo().hlo_module().name();
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ TF_ASSIGN_OR_RETURN(Literal result_literal,
client->ShapedBufferToLiteral(*result));
- return std::move(*result_literal);
+ return result_literal;
}
StatusOr<HloSnapshot> ParseInputFile(const string& filename,
@@ -305,11 +304,11 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
result.ToString().c_str());
auto& snapshot = snapshots[i];
if (snapshot.has_result()) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
- literal->ToString().c_str());
+ literal.ToString().c_str());
}
}
}
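The replay path round-trips literals through LiteralProto; a condensed sketch of the deserialize-and-upload step, assuming a proto taken from the snapshot and a LocalClient* client as above:

  TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer data,
      client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));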
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index 51909190a3..4f8852f8c1 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -40,8 +40,8 @@ int main(int argc, char **argv) {
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal_proto));
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
- fprintf(stderr, "%s\n", literal->ToString().c_str());
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
}
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 48c8374811..4b5c276bdf 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -36,16 +36,16 @@ int main(int argc, char **argv) {
LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
}
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
- LOG(INFO) << "literal: " << *literal;
- fprintf(stderr, "%s\n", literal->ToString().c_str());
- if (literal->shape().element_type() == xla::F32) {
- float min = *std::min_element(literal->data<float>().begin(),
- literal->data<float>().end());
- float max = *std::max_element(literal->data<float>().begin(),
- literal->data<float>().end());
+ LOG(INFO) << "literal: " << literal;
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
+ if (literal.shape().element_type() == xla::F32) {
+ float min = *std::min_element(literal.data<float>().begin(),
+ literal.data<float>().end());
+ float max = *std::max_element(literal.data<float>().begin(),
+ literal.data<float>().end());
fprintf(stderr, "min: %a=%f\n", min, min);
fprintf(stderr, "max: %a=%f\n", max, max);
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index dd329f1181..73b3589dbf 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -351,6 +351,7 @@ message DeviceAssignmentProto {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
+ bytes s8s = 15;
bytes u8s = 3;
repeated int32 s32s = 4;
repeated int64 s64s = 5;
@@ -364,7 +365,7 @@ message LiteralProto {
bytes f16s = 11;
bytes bf16s = 13;
repeated int64 sparse_indices = 14;
- // Next = 15
+ // Next = 16
}
message WindowDimension {
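The proto change above gives S8 element data its own carrier: the new field `bytes s8s = 15` parallels the existing `bytes u8s = 3`, and the trailing comment tracking the next free field number moves from 15 to 16. A hedged C++ sketch of the serialization this enables, assuming Literal::ToProto fills the field matching the element type:

  // An S8 literal's bytes should now land in LiteralProto.s8s (field 15).
  Literal s8_literal = LiteralUtil::CreateR1<int8>({1, -2, 3});
  LiteralProto proto = s8_literal.ToProto();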