Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- tensorflow/compiler/xla/service/BUILD | 74
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 323
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/batchnorm_expander.cc | 12
-rw-r--r-- tensorflow/compiler/xla/service/batchnorm_expander_test.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc | 18
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment_test.cc | 7
-rw-r--r-- tensorflow/compiler/xla/service/buffer_liveness_test.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/call_graph_test.cc | 26
-rw-r--r-- tensorflow/compiler/xla/service/call_inliner_test.cc | 12
-rw-r--r-- tensorflow/compiler/xla/service/compile_only_service.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 6
-rw-r--r-- tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/sample_harness.cc | 30
-rw-r--r-- tensorflow/compiler/xla/service/cpu/shape_partition_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc | 35
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc | 66
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/flatten_call_graph_test.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/generic_transfer_manager.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 53
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 55
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 129
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 167
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 81
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h | 44
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 38
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 7
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 61
-rw-r--r-- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 7
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 16
-rw-r--r-- tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc | 32
-rw-r--r-- tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/heap_simulator_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo.proto | 7
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 37
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils.cc | 11
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils_test.cc | 68
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse_test.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.cc | 253
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.h | 57
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 503
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 203
-rw-r--r-- tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 7
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 18
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 15
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h | 12
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling.cc) | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler.h (renamed from tensorflow/compiler/xla/service/hlo_scheduling.h) | 38
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling_test.cc) | 28
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc | 53
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_dce.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_dce_test.cc | 72
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group.cc | 91
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group.h | 81
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_test.cc | 142
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_test.cc | 95
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering_test.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 54
-rw-r--r-- tensorflow/compiler/xla/service/hlo_reachability_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 88
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.h | 83
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 79
-rw-r--r-- tensorflow/compiler/xla/service/hlo_runner.cc | 28
-rw-r--r-- tensorflow/compiler/xla/service/hlo_runner.h | 25
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule_test.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/indexed_array_analysis.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/indexed_array_analysis.h | 14
-rw-r--r-- tensorflow/compiler/xla/service/inliner_test.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.cc | 260
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.h | 34
-rw-r--r-- tensorflow/compiler/xla/service/interpreter/executable.cc | 15
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment_test.cc | 107
-rw-r--r-- tensorflow/compiler/xla/service/service.cc | 49
-rw-r--r-- tensorflow/compiler/xla/service/service.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/source_map_util.cc | 66
-rw-r--r-- tensorflow/compiler/xla/service/transfer_manager.cc | 12
-rw-r--r-- tensorflow/compiler/xla/service/transfer_manager.h | 8
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/tuple_simplifier_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_analysis.cc | 19
107 files changed, 2461 insertions, 1922 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e784663ff6..fb80c78f68 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -87,6 +87,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -123,6 +124,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -352,6 +354,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -402,6 +405,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -498,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -546,6 +551,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -568,6 +574,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1012,8 +1019,8 @@ cc_library(
":buffer_value_containers",
":heap_simulator",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_proto",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1041,8 +1048,8 @@ tf_cc_test(
":cpu_plugin",
":flatten_call_graph",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1088,8 +1095,8 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1131,6 +1138,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -1139,6 +1147,37 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_module_group",
+ srcs = ["hlo_module_group.cc"],
+ hdrs = ["hlo_module_group.h"],
+ deps = [
+ ":hlo",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_module_group_test",
+ srcs = ["hlo_module_group_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_module_group",
+ ":hlo_parser",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_module_group_metadata",
srcs = ["hlo_module_group_metadata.cc"],
hdrs = ["hlo_module_group_metadata.h"],
@@ -1185,9 +1224,9 @@ tf_cc_test(
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1199,13 +1238,14 @@ tf_cc_test(
)
cc_library(
- name = "hlo_scheduling",
- srcs = ["hlo_scheduling.cc"],
- hdrs = ["hlo_scheduling.h"],
+ name = "hlo_memory_scheduler",
+ srcs = ["hlo_memory_scheduler.cc"],
+ hdrs = ["hlo_memory_scheduler.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1219,15 +1259,15 @@ cc_library(
)
tf_cc_test(
- name = "hlo_scheduling_test",
- srcs = ["hlo_scheduling_test.cc"],
+ name = "hlo_memory_scheduler_test",
+ srcs = ["hlo_memory_scheduler_test.cc"],
deps = [
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1259,6 +1299,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1392,6 +1433,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -1708,6 +1750,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -1777,6 +1820,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -1953,6 +1997,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
+ ":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2236,6 +2281,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2314,6 +2360,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2394,12 +2441,11 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
- ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -2428,6 +2474,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -2494,6 +2541,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2611,6 +2659,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -2888,6 +2937,7 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
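
The headline change in this BUILD file is the rename of hlo_scheduling to hlo_memory_scheduler, with a new ":hlo_pass" dependency suggesting the scheduler is now usable as an HLO pass. A minimal sketch of how a client might wire it into a pipeline under that assumption; the HloMemoryScheduler constructor signature below is an assumption, not taken from this diff:

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"  // was hlo_scheduling.h
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"

void AddMemorySchedulingPass(xla::HloPassPipeline* pipeline) {
  // Assumed constructor argument: a buffer size function, as other
  // scheduling entry points take.
  pipeline->AddPass<xla::HloMemoryScheduler>(
      [](const xla::LogicalBuffer& buffer) {
        return xla::ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
      });
}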
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3d18fe3be2..5458159d14 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
return scalar_add_computation_;
}
+ // Tries to fold a kPad in the input or filter into the convolution
+ // instruction's window.
+ StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+ StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+ // Tries to use a kDot in place of the given convolution.
+ StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplification on platforms where it causes a slowdown.
+ // Disable convolution -> dot simplification on platforms where it causes a
+ // slowdown.
bool enable_conv_simplification_;
// Cached computation for adding two scalar F32.
@@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ HloInstruction::CreateConstant(literal.Clone()));
}
}
@@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
+ Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
- HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -1469,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@@ -1572,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1607,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2057,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (pad_literal == reduce_init_literal) {
return true;
}
- auto converted_pad_literal = pad_literal.ConvertToShape(
- reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ auto converted_pad_literal =
+ pad_literal.ConvertToShape(reduce_init_value->shape());
if (!converted_pad_literal.ok()) {
return false;
}
- return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2212,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
HloInstruction* convolution) {
- auto lhs = convolution->mutable_operand(0);
- auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
- ShapeUtil::IsZeroElementArray(rhs->shape())) {
- return ReplaceWithNewInstruction(
- convolution,
- HloInstruction::CreateBroadcast(
- convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type())
- .CloneToUnique())),
- {}));
- }
-
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
const auto& window = convolution->window();
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
- // Try to merge padding/dilation of the input with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
- if (lhs->opcode() != HloOpcode::kPad) {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(lhs->operand(1), 0)) {
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
return false;
}
- const auto& padding = lhs->padding_config();
-
- // Can't pad batch or feature dims.
- for (int64 dim :
- {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
}
+ }
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = window;
- for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
- // Edge padding composes with itself in the straightforward way, but
- // composing interior padding is nontrivial, and we cowardly refuse to
- // think about it. If we see interior padding in either the kPad or conv,
- // bail if there's any sort of padding in the other.
- if (p.interior_padding() != 0 &&
- (w.padding_low() != 0 || w.padding_high() != 0 ||
- w.base_dilation() != 1)) {
- return false;
- }
- if (w.base_dilation() != 1 &&
- (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0)) {
- return false;
- }
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
- w.set_padding_low(w.padding_low() + p.edge_padding_low());
- w.set_padding_high(w.padding_high() + p.edge_padding_high());
- if (p.interior_padding() != 0) {
- CHECK_EQ(w.base_dilation(), 1);
- w.set_base_dilation(1 + p.interior_padding());
- }
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs->mutable_operand(0), rhs});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
- if (folded_input_pad) {
- return Status::OK();
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
}
- // Try to merge dilation of the filter with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
- if (rhs->opcode() != HloOpcode::kPad) {
- return false;
- }
+ const auto& padding = rhs->padding_config();
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(rhs->operand(1), 0)) {
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- const auto& padding = rhs->padding_config();
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
- // Can't pad or dilate feature dims.
- for (int64 dim : {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
}
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = convolution->window();
- for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
-
- // We can only do this transformation if p adds dilation to the filter --
- // edge padding on the filter is not supported in conv.
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
- return false;
- }
-
- // Nothing to do if the kPad for this dim is entirely a nop.
- if (p.interior_padding() == 0) {
- continue;
- }
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
- // We cowardly refuse to think about how dilation composes with itself;
- // bail if both the kPad and conv have dilation on this dimension.
- if (w.window_dilation() > 1) {
- return false;
- }
- CHECK_EQ(w.window_dilation(), 1);
- w.set_window_dilation(1 + p.interior_padding());
- w.set_size(rhs->operand(0)->shape().dimensions(
- dnums.kernel_spatial_dimensions(dim)));
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
}
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs, rhs->mutable_operand(0)});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
- if (folded_filter_pad) {
- return Status::OK();
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
if (!enable_conv_simplification_) {
- return Status::OK();
+ return false;
}
- // HandleConvolution tries to replace a convolution with a DOT instruction.
- //
- // Only add when bitcasts can be used:
- // - if bitcasts are not supported, then reshapes could be used but will
- // end up with another copy.
- // - if bitcasts are supported, the simplifier will be called again with
- // bitcasts_ == true.
- // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+ // layout-insensitive mode, for fear of adding nontrivial reshapes.
if (!is_layout_sensitive_) {
- return Status::OK();
+ return false;
}
const Shape& input_shape = lhs->shape();
@@ -2388,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Require the spatial dimensions in the kernel to have a bound of one.
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
- return Status::OK();
+ return false;
}
}
@@ -2399,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
- return Status::OK();
+ return false;
}
// Also, the shapes must align for a rowmajor matmul:
@@ -2425,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dnums.kernel_input_feature_dimension()) <
PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
dnums.kernel_output_feature_dimension()))) {
- return Status::OK();
+ return false;
}
auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2467,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
- return Status::OK();
+ return false;
}
auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2479,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
convolution->precision_config()));
- return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution) {
+ // Zero-sized input or filter.
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+ return ReplaceWithNewInstruction(
+ convolution,
+ HloInstruction::CreateBroadcast(
+ convolution->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type()))),
+ {}));
+ }
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
+ // Try to replace the convolution with a kDot instruction.
+ TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+ if (replaced_with_dot) {
+ return Status::OK();
+ }
+
+ return Status::OK();
}
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
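
The net effect of this refactor: HandleConvolution becomes a thin driver over three StatusOr<bool> helpers (FoldConvInputPad, FoldConvFilterPad, SimplifyConvToDot), each reporting whether it rewrote the graph, with the lambda-wrapped early returns flattened into plain "return false". The pad-merging rule itself is simple arithmetic: edge padding adds onto the window's padding, and interior padding p becomes base dilation 1 + p, with mixed cases rejected. A self-contained numeric sketch of that rule, using plain ints in place of the Window/PaddingConfig protos:

#include <cassert>

// One spatial dimension of a conv window, reduced to the fields the merge
// touches.
struct WindowDim { int pad_lo = 0, pad_hi = 0, base_dilation = 1; };

// Mirrors the per-dimension logic in FoldConvInputPad; returns false when the
// kPad cannot be folded into this window dimension.
bool MergePadIntoWindowDim(WindowDim& w, int edge_lo, int edge_hi,
                           int interior) {
  // Interior padding composes nontrivially with existing padding/dilation,
  // so the pass bails in either mixed case.
  if (interior != 0 &&
      (w.pad_lo != 0 || w.pad_hi != 0 || w.base_dilation != 1)) {
    return false;
  }
  if (w.base_dilation != 1 &&
      (edge_lo != 0 || edge_hi != 0 || interior != 0)) {
    return false;
  }
  w.pad_lo += edge_lo;
  w.pad_hi += edge_hi;
  if (interior != 0) w.base_dilation = 1 + interior;  // dilation = 1 + p
  return true;
}

int main() {
  WindowDim w;  // no padding, no dilation
  assert(MergePadIntoWindowDim(w, /*edge_lo=*/1, /*edge_hi=*/2, /*interior=*/0));
  assert(w.pad_lo == 1 && w.pad_hi == 2);

  WindowDim w2;
  assert(MergePadIntoWindowDim(w2, 0, 0, /*interior=*/2));
  assert(w2.base_dilation == 3);  // 1 + interior_padding
}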
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index a0db4563fb..3fc1ba2427 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector)};
+ Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
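
This test change shows the Literal ownership migration that runs through the whole commit: LiteralUtil factories now return Literal by value rather than std::unique_ptr<Literal>, and MakeTuple takes raw pointers to caller-owned elements that it copies. A minimal sketch of the new idiom (the helper below is illustrative, not part of the change):

#include "tensorflow/compiler/xla/literal_util.h"

// Builds the tuple literal (7.3f, {1.1f, 2.0f, 3.3f}) with value semantics;
// the element literals live on the stack and MakeTuple copies them.
xla::Literal MakeExampleTuple() {
  xla::Literal scalar = xla::LiteralUtil::CreateR0<float>(7.3f);
  xla::Literal vector = xla::LiteralUtil::CreateR1<float>({1.1f, 2.0f, 3.3f});
  return xla::LiteralUtil::MakeTuple({&scalar, &vector});
}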
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec281ae68f..30d33e0d35 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
+ elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
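
The same value-semantics migration covers conversions: Convert is now called on the Literal itself and yields a StatusOr<Literal>, so TF_ASSIGN_OR_RETURN can rebind the local in place, as above. A small sketch of the pattern, assuming only the signatures used in this file:

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"

// Makes a scalar of the requested primitive type from a float seed, mirroring
// the zero/epsilon handling in BatchNormExpander.
xla::StatusOr<xla::Literal> ScalarOfType(float value,
                                         xla::PrimitiveType ptype) {
  xla::Literal literal = xla::LiteralUtil::CreateR0(value);
  return literal.Convert(ptype);
}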
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
- for (auto* instruction : module->entry_computation()->instructions()) {
+ for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
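
A recurring pattern in this commit is the move from HloTestBase to HloVerifiedTestBase, which owns the module, verifies it, and hands out raw pointers/references (hence the dropped .get() calls); the bfloat16 tests below additionally construct the base with /*allow_mixed_precision=*/true. A sketch of the resulting test shape, with an illustrative pass name and module string that are not from the change:

#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"

namespace xla {
namespace {

using ExamplePassTest = HloVerifiedTestBase;

TEST_F(ExamplePassTest, ParsesAndVerifiesModule) {
  const char* kModuleStr = R"(
HloModule m

ENTRY entry {
  ROOT param = f32[2]{0} parameter(0)
})";
  // ParseAndVerifyModule replaces TF_ASSERT_OK_AND_ASSIGN + ParseHloString;
  // module() then returns a reference to the owned, verified module.
  ParseAndVerifyModule(kModuleStr);
  EXPECT_EQ(module().entry_computation()->num_parameters(), 1);
}

}  // namespace
}  // namespace xla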
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21c3b..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 933cf873e0..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 545a6ecfb1..58f78f8e24 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
- TF_ASSIGN_OR_RETURN(
- auto converted_literal,
- hlo->literal().ConvertToShape(hlo->shape(),
- /*round_f32_to_bf16=*/true));
+ TF_ASSIGN_OR_RETURN(auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
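
ConvertToShape loses its round_f32_to_bf16 flag here; the conversion is now driven entirely by the target shape. A sketch of an F32-to-BF16 literal conversion under the new signature (the helper is illustrative and assumes xla::Shape's set_element_type mutator):

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"

// Rounds an F32 literal to BF16 by converting to the same shape with a BF16
// element type.
xla::StatusOr<xla::Literal> ToBF16(const xla::Literal& f32_literal) {
  xla::Shape bf16_shape = f32_literal.shape();
  bf16_shape.set_element_type(xla::BF16);
  return f32_literal.ConvertToShape(bf16_shape);
}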
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 388fd5df99..e032b5c624 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 0f0af57626..65fa951afe 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 5a231c173d..795beb9ff5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -30,11 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
+ Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()})));
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 414bfe7999..17e5090505 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()});
- auto inner_tuple1 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
+ auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+ Literal element1 = LiteralUtil::CreateR0<int64>(3);
+ auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- inner_tuple0->shape(), tuple_constant, 0));
+ inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
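
Nested tuples follow the same ownership pattern: each MakeTuple level takes pointers to literals the caller keeps alive, and .shape() is now a member access on a value rather than through a unique_ptr. A sketch building the ((0, 1), (3)) constant from the test above:

#include "tensorflow/compiler/xla/literal_util.h"

// Builds the nested tuple literal ((0, 1), (3)); every element literal is
// stack-owned and copied by MakeTuple.
xla::Literal MakeNestedTuple() {
  xla::Literal e0 = xla::LiteralUtil::CreateR0<xla::int64>(0);
  xla::Literal e1 = xla::LiteralUtil::CreateR0<xla::int64>(1);
  xla::Literal inner0 = xla::LiteralUtil::MakeTuple({&e0, &e1});
  xla::Literal e3 = xla::LiteralUtil::CreateR0<xla::int64>(3);
  xla::Literal inner1 = xla::LiteralUtil::MakeTuple({&e3});
  return xla::LiteralUtil::MakeTuple({&inner0, &inner1});
}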
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74843..34f3f914d5 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(1, call_graph->nodes().size());
EXPECT_TRUE(call_graph->IsFlattened());
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
// The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(3, call_graph->nodes().size());
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Test visitation of only reachable nodes.
{
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
auto module = CreateNewModule();
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
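The repeated module.get() to module edits in this file follow from switching the fixture to HloVerifiedTestBase, whose CreateNewModule (as these hunks suggest) returns a raw HloModule* owned by the fixture and verified on teardown, rather than a std::unique_ptr<HloModule>. A sketch of the resulting test pattern, under that ownership assumption:

    TEST_F(CallGraphTest, SingletonComputation) {
      // The fixture owns the module and runs the HLO verifier on it later.
      HloModule* module = CreateNewModule();
      module->AddEntryComputation(MakeScalarComputation());
      // No .get(): CallGraph::Build already receives the raw pointer.
      std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
      EXPECT_EQ(1, call_graph->nodes().size());
    }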
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f173..e6b5665435 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, pred, "param"));
+ call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}
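The two lines added in the CallsWithinWhileBodiesAreInlined hunk look like a latent test bug surfaced by the verifying fixture: a computation wired in as a while condition must declare its parameter even if its body never uses it. A minimal sketch of the corrected builder sequence, assuming that is the motivation (pred is the scalar PRED shape from the test):

    HloComputation::Builder call_false_builder(TestName() + ".call_false");
    // Declare parameter 0 so the verifier accepts the computation's
    // signature; the kCall below ignores it.
    call_false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, pred, "param"));
    call_false_builder.AddInstruction(
        HloInstruction::CreateCall(pred, {}, false_computation));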
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28478..96bd2616f5 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(instance.computation, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 0826380f65..0ac4a65ec6 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(
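The zero-constant hunk above shows the same Literal migration on the instruction-building side: HloInstruction::CreateConstant now consumes a Literal by value, so the absl::make_unique<Literal> wrapper disappears. A minimal sketch (the helper name is illustrative):

    HloInstruction* MakeZeroConstant(HloComputation::Builder* builder,
                                     PrimitiveType element_type) {
      // The Literal moves directly into the instruction; no heap wrapper.
      return builder->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    }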
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 2368ac8c6a..8cc522a59e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
@@ -801,6 +801,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -822,6 +823,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -946,6 +948,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -971,6 +974,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 05792795a1..2083f440fd 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {
@@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e7b6075994..18fc144efe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,12 +77,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa446e..c9fb34be1c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*module), 3);
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6de6..be1208fb2d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 942e2ddd39..55d5925642 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -37,21 +37,20 @@ int main(int argc, char** argv) {
xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
// Transfer parameters.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@ int main(int argc, char** argv) {
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
- xla::StatusOr<std::unique_ptr<xla::Literal>> result =
- client->ExecuteAndTransfer(
- computation,
- /*arguments=*/{param0_data.get(), param1_data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/&profile);
- std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+ xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*execution_options=*/nullptr,
+ /*execution_profile=*/&profile);
+ xla::Literal actual = result.ConsumeValueOrDie();
LOG(INFO) << absl::StrFormat("computation took %dns",
profile.compute_time_ns());
- LOG(INFO) << actual->ToString();
+ LOG(INFO) << actual.ToString();
return 0;
}
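The harness rewrite above also shows the client-facing side of the migration: TransferToServer accepts the literal itself instead of a dereferenced pointer, and ExecuteAndTransfer yields StatusOr<Literal> directly. A condensed sketch of the round trip, assuming a LocalClient* client and a built XlaComputation computation:

    xla::Literal arg = xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f});
    std::unique_ptr<xla::GlobalData> arg_data =
        client->TransferToServer(arg).ConsumeValueOrDie();
    // StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>.
    xla::Literal actual =
        client->ExecuteAndTransfer(computation, {arg_data.get()})
            .ConsumeValueOrDie();
    LOG(INFO) << actual.ToString();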
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index 7d8e51f909..1a3d82de95 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
protected:
typedef std::vector<int64> Vec;
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
}
}
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index f11aff0573..c55206eee7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,6 +48,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 22721051e5..1deb412064 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
protected:
CpuFusionTest() {}
@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
- Shape vshape = input_literal1->shape();
+ Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
- error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
<< fusion_instruction2->fused_instructions_computation()->ToString();
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
- *result, error_spec_);
+ result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// Run fusion.
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
auto fusion1 = result->operand(0);
auto fusion2 = result->operand(1);
@@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The only fusion instruction should be operand 0 of the tuple (formerly
// negate1).
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c661..5cc6d01c0f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+ LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+ LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace
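The tuple infeeds above migrate from MakeTuple over .get() pointers to LiteralUtil::MakeTupleFromSlices, which accepts the element literals as slices of temporaries and returns the tuple by value. A minimal sketch of the new call shape:

    // Element literals are built as temporaries; no unique_ptr plumbing.
    xla::Literal tuple = xla::LiteralUtil::MakeTupleFromSlices(
        {xla::LiteralUtil::CreateR1<float>({1, 2}),
         xla::LiteralUtil::CreateR0<bool>(true)});
    ASSERT_IS_OK(client_->TransferToInfeed(tuple));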
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index bb105194f1..7af51db55a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 1b3be199f6..852f34e06d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -56,9 +56,9 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
- RunTest(hlo_text, {lhs.get(), rhs.get()});
+ Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 4ed91ef187..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
- std::unique_ptr<Literal> relayed_out_literal;
+ Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->untyped_data();
+ source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6791e15ee0..64b9683628 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -173,6 +174,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
@@ -370,6 +372,8 @@ cc_library(
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
deps = [
+ ":backend_configs",
+ ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -395,6 +399,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -813,9 +818,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
)
@@ -832,6 +837,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -901,6 +907,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 05448d863d..3a23ac1d63 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,62 +31,32 @@ namespace gpu {
using se::dnn::AlgorithmDesc;
-ConvolutionThunk::ConvolutionThunk(
- CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
- int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
- : Thunk(Kind::kConvolution, hlo),
- convolution_kind_(convolution_kind),
- input_buffer_(input_buffer),
- filter_buffer_(filter_buffer),
- output_buffer_(output_buffer),
- tuple_result_buffer_(tuple_result_buffer),
- scratch_buffer_(scratch_buffer),
- input_shape_(input_shape),
- filter_shape_(filter_shape),
- output_shape_(output_shape),
- window_(window),
- dim_nums_(dim_nums),
- feature_group_count_(feature_group_count),
- algorithm_(algorithm),
- tensor_ops_enabled_(tensor_ops_enabled) {}
-
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- se::DeviceMemoryBase input_data =
- buffer_allocations.GetDeviceAddress(input_buffer_);
- se::DeviceMemoryBase filter_data =
- buffer_allocations.GetDeviceAddress(filter_buffer_);
- se::DeviceMemoryBase output_data =
- buffer_allocations.GetDeviceAddress(output_buffer_);
+ CudnnConvParams params;
+
+ params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+ params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(
- convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_,
- feature_group_count_, algorithm_config, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
void* result_ptr = [&] {
- switch (convolution_kind_) {
+ switch (params.kind) {
case CudnnConvKind::kForward:
- return output_data.opaque();
+ return params.output_buf.opaque();
case CudnnConvKind::kBackwardInput:
- return input_data.opaque();
+ return params.input_buf.opaque();
case CudnnConvKind::kBackwardFilter:
- return filter_data.opaque();
+ return params.filter_buf.opaque();
}
}();
void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c40c5..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@ limitations under the License.
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums,
- int64 feature_group_count, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 feature_group_count_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu
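With the constructor slimmed down as above, everything shape- and algorithm-related now lives on the custom-call instruction and is recovered at execution time via PopulateCudnnConvParams; the thunk itself keeps only buffer slices. The "output" naming discussed in the class comment is easiest to see in the result selection, condensed here from the .cc hunk:

    // The buffer written as the conv result is the forward conv's "output"
    // for kForward, its "input" for kBackwardInput, and its "filter" for
    // kBackwardFilter.
    void* result_ptr = [&] {
      switch (params.kind) {
        case CudnnConvKind::kForward:
          return params.output_buf.opaque();
        case CudnnConvKind::kBackwardInput:
          return params.input_buf.opaque();
        case CudnnConvKind::kBackwardFilter:
          return params.filter_buf.opaque();
      }
    }();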
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5c2555148a..f528e62b17 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/mutex.h"
@@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr) {
+ const HloCustomCallInstruction* instr) {
+ CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
+
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -216,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
allocator = &*se_allocator;
}
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- if (cross_check_enabled) {
- // Broadcast a constant to the buffer, instead of zeroing the buffer. A
- // non-zero constant is useful for the cross checking, because zero-inputs
- // may not always reveal the bugs.
- const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ const auto initialize_buffer = [&stream, cross_check_enabled](
+ DeviceMemoryBase buffer) {
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for cross-checking, because all-zero
+ // inputs may not always reveal bugs.
CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
size_t left_over_bytes = buffer.size() % 4;
CHECK_EQ(0, left_over_bytes % 2);
@@ -252,33 +244,46 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
DeviceMemoryBase left_over(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
- };
- initialize_f16(input_buf);
- initialize_f16(filter_buf);
- initialize_f16(output_buf);
- } else {
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the
- // inputs might affect performance if e.g. the inputs contain denormals, and
- // this is easy enough.
- stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size());
- }
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ stream.ThenMemZero(&buffer, buffer.size());
+ }
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ TF_ASSIGN_OR_RETURN(params.input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(params.filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(params.output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
+
+ initialize_buffer(params.input_buf);
+ initialize_buffer(params.filter_buf);
+ initialize_buffer(params.output_buf);
DeviceMemoryBase* result_buf = [&] {
- switch (kind) {
+ switch (params.kind) {
case CudnnConvKind::kBackwardFilter:
- return &filter_buf;
+ return &params.filter_buf;
case CudnnConvKind::kBackwardInput:
- return &input_buf;
+ return &params.input_buf;
case CudnnConvKind::kForward:
- return &output_buf;
+ return &params.output_buf;
}
}();
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, dnums, stream_exec_);
+ input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
@@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
for (const AlgorithmDesc& alg :
- GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+ GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok =
- RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf,
- filter_buf, output_buf, &scratch_allocator, window, dnums,
- feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ params.algorithm = AlgorithmConfig(alg);
+ bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
+ &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
const bool crash_on_checking_failure =
@@ -374,34 +377,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- const auto& call_target = instr->custom_call_target();
- const auto& lhs_shape = instr->operand(0)->shape();
- const auto& rhs_shape = instr->operand(1)->shape();
- const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
- if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc =
- PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr->feature_group_count(),
- instr);
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(),
- instr->feature_group_count(), instr);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instr->ToString();
- }
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
if (!alg_scratch_and_tc.ok()) {
LOG(ERROR) << alg_scratch_and_tc.status();
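The removed per-call-target dispatch above is replaced by a single entry point: PickBestAlgorithm now takes the HloCustomCallInstruction and lets PopulateCudnnConvParams recover the kind, shapes, window, and dimension numbers. A condensed sketch of the autotuning loop, using only names that appear in the hunks above:

    CudnnConvParams params;
    TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
    for (const AlgorithmDesc& alg :
         GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
      ScratchAllocator scratch_allocator(device_ordinal, allocator);
      se::dnn::ProfileResult profile_result;
      params.algorithm = AlgorithmConfig(alg);
      // Each candidate reruns the same conv; valid profiled runs compete
      // for best_result.
      bool launch_ok = RunCudnnConvolution(params, &scratch_allocator,
                                           &stream, &profile_result)
                           .ok();
    }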
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 0cb01161b0..f79b113f8f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- HloInstruction* instr);
+ const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721ecd2..228379a248 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include <cstdlib>
#include <numeric>
#include <vector>
@@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward filter.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
- HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
- std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward input.
+ // TODO(b/31709653): In theory cuDNN also supports grouped convolutions for
+ // the backward input convolution, but at least as of version 7.1.4 it is
+ // slower. This needs to be re-evaluated for future cuDNN versions.
+ // Note that the necessary code is already in place below; all that is needed
+ // to enable it is to remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else if (reverse_filter->IsConstant()) {
- // If the filter is a constant, we're willing to pattern-match to a
- // backwards-input conv, on the theory that
- //
- // a) reversing a constant is free, and
- // b) even if the user specified this filter as reverse(constant), we would
- // long ago have constant-folded away the reverse.
- //
- // If the constant has any other uses, reversing it isn't entirely free,
- // since we'd now have two constants to keep in memory. But hopefully it's
- // free enough.
- //
- // TODO(jlebar): Should we do this even if the filter is not a constant?
- // Reversing a non-constant filter is probably cheaper than padding the
- // input!
-
- // Nothing to do, just fall through.
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
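
For intuition on the base-dilation criterion, consider a hypothetical 1-D example (not from the source): if a forward convolution computes y = conv(x, w) with stride 2, then the gradient with respect to x is itself a convolution, namely of the base-dilated output gradient with the reversed filter:

  // dx = conv(base_dilate(dy, /*factor=*/2), reverse(w))

So a conv whose filter is reversed (or trivially reversible: constant or 1x1) and whose window has base dilation has exactly the shape of a backward-input convolution, which is what the check above tests.
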
// Match padding and dilation of the forward convolution.
@@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
- // This simplifies things for our caller, and algebraic-simplifier will later
- // remove any unnecessary reverses.
- if (reverse_filter->opcode() != HloOpcode::kReverse) {
+ // OK, it's a match! Swap the input feature dimension with the output
+ // feature dimension, since that is how cuDNN expects them to be laid out.
+ dnums.set_kernel_input_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
// Create a double-reverse, which is a nop.
HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
- dnums.set_kernel_input_feature_dimension(
- conv->convolution_dimension_numbers().kernel_output_feature_dimension());
- dnums.set_kernel_output_feature_dimension(
- conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- return std::make_tuple(true, new_window, dnums);
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(true, new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature
+ // dimension with the output feature dimension, we also need to reshape the
+ // kernel so that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ Shape new_shape = rhs->shape();
+ int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+ // In the backward convolution case, the spatial dimensions become the
+ // feature dimensions, and we are guaranteed that the spatial dimensions are
+ // adjacent.
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+ int64 input_features = new_shape.dimensions(input_feature_dimension);
+ int64 output_features = new_shape.dimensions(output_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension,
+ input_features / conv->feature_group_count());
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * conv->feature_group_count());
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(true, new_window, dnums, rhs);
}
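
A numeric illustration of the reshape above, with hypothetical shapes and feature_group_count = 2:

  //   rhs before reshape: f32[3,3,8,6]   (kernel input features: 8,
  //                                        kernel output features: 6)
  //   rhs after reshape:  f32[3,3,4,12]  (8 / 2 = 4 input features per group,
  //                                        6 * 2 = 12 output features)
  // The element count is unchanged (3*3*8*6 == 3*3*4*12); only the grouping
  // interpretation of the two feature dimensions changes.
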
// Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
+ HloInstruction* rhs;
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
@@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
window, dnums, conv->feature_group_count());
}
- std::tie(match, window, dnums) = MatchBackwardInput(conv);
+ std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- // Backward input conv subsumes the conv plus the reverse in operand 1.
- HloInstruction* reverse = conv->mutable_operand(1);
- CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
- HloInstruction* rhs = reverse->mutable_operand(0);
-
return CreateCudnnConvBackwardInput(conv->shape(),
conv->mutable_operand(0), rhs, window,
dnums, conv->feature_group_count());
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index bda8ebe579..d237f8930b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
- LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9d1f..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
- output_buf, &scratch_allocator, window, dnums, feature_group_count,
- algorithm, stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
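
For reference, a one-shot allocator in the spirit of the ScratchBufAllocator used above could be written as follows. This is a sketch against the se::ScratchAllocator interface of this era, not the class defined earlier in this file:

  class OneShotScratchAllocator : public se::ScratchAllocator {
   public:
    explicit OneShotScratchAllocator(se::DeviceMemoryBase buf) : buf_(buf) {}
    int64 GetMemoryLimitInBytes(se::Stream* stream) override {
      return buf_.size();
    }
    se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
        se::Stream* stream, int64 byte_size) override {
      CHECK(!allocated_) << "scratch may be handed out only once";
      CHECK_LE(byte_size, static_cast<int64>(buf_.size()));
      allocated_ = true;
      return se::DeviceMemory<uint8>(buf_);
    }

   private:
    se::DeviceMemoryBase buf_;
    bool allocated_ = false;
  };
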
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index a1b4fc71d0..381aa37a1b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@ enum class CudnnConvKind {
kBackwardFilter, // input + output => filter
};
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
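A hypothetical call site, to show how the struct is meant to be filled; every local name here is assumed, and the real population logic lives in PopulateCudnnConvParams further down in this change:

  CudnnConvParams params;
  params.kind = CudnnConvKind::kForward;
  params.input_shape = &lhs_shape;    // non-owning pointers
  params.filter_shape = &rhs_shape;
  params.output_shape = &result_shape;
  params.input_buf = input_buf;       // type-erased DeviceMemoryBase
  params.filter_buf = filter_buf;
  params.output_buf = output_buf;
  params.window = &conv->window();
  params.dnums = &conv->convolution_dimension_numbers();
  params.feature_group_count = conv->feature_group_count();
  params.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(algo_id, tensor_ops_enabled));
  TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch_buf, stream));
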
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
//
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment convolution with half data type is implemented with cudnn
+// PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
@@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index ea9376e101..02a0d028c1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 59ade96f7d..b857fa775a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -24,14 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089df4c..27a4d0b601 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 20d523abe0..22f43bc08b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params) {
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
+ const auto& target = custom_call->custom_call_target();
+ const auto& lhs_shape = custom_call->operand(0)->shape();
+ const auto& rhs_shape = custom_call->operand(1)->shape();
+ const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+ params->window = &custom_call->window();
+ params->dnums = &custom_call->convolution_dimension_numbers();
+ params->feature_group_count = custom_call->feature_group_count();
+ params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params->kind = CudnnConvKind::kForward;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &conv_result_shape;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params->kind = CudnnConvKind::kBackwardInput;
+ params->input_shape = &conv_result_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &lhs_shape;
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params->kind = CudnnConvKind::kBackwardFilter;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &conv_result_shape;
+ params->output_shape = &rhs_shape;
+ } else {
+ LOG(FATAL) << "Unexpected custom call target: "
+ << custom_call->custom_call_target();
+ }
+ return Status::OK();
+}
+
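Summarizing the operand-to-parameter mapping this helper encodes (the same facts as the branches above, in table form):

  //   kind             input_shape   filter_shape   output_shape
  //   kForward         lhs           rhs            conv result
  //   kBackwardInput   conv result   rhs            lhs
  //   kBackwardFilter  lhs           conv result    rhs
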
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 59c65fc268..09c455cc1e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
+// Populates params using custom_call, which must be a custom-call to a cudnn
+// convolution. Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f91cc00d71..b669881026 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
- std::unique_ptr<ConvolutionThunk> thunk;
+ BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
if (target == kCudnnConvForwardCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kForward,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/conv_result_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = rhs_slice;
+ output_slice = conv_result_slice;
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardInput,
- /*input_buffer=*/conv_result_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/lhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/lhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = conv_result_slice;
+ filter_slice = rhs_slice;
+ output_slice = lhs_slice;
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardFilter,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/conv_result_slice,
- /*output_buffer=*/rhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape,
- /*output_shape=*/rhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = conv_result_slice;
+ output_slice = rhs_slice;
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
}
- thunk_sequence_->emplace_back(std::move(thunk));
+ thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+ Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+ output_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index f6325b3368..dfdcf1875d 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
- // CudnnConvolutionRewriter may add instructions of the form
- // reverse(constant), which it expects will be simplified by constant
- // folding.
- pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// pairs that TupleSimplifier fixes.
pipeline.AddPass<TupleSimplifier>();
}
+ // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add
+ // instructions which can be simplified by constant folding.
+ pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index fa84d77223..b0061fa655 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-
// We want the input/output feature counts of an f16 conv to be multiples of
// 8, because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
- auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
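
As a concrete example of the effect of this pass, with hypothetical sizes:

  // An f16 conv with 41 input and 20 output features gets both feature
  // dimensions zero-padded up to the next multiple of 8 (41 -> 48, 20 -> 24),
  // which lets cudnn dispatch the conv to tensor cores. The padded input
  // features are zeros and do not change the result; the padded output
  // features are sliced away again after the conv.
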
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9d85d746d8..2a6415d0b6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
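
The CreateConstant changes above, and the Literal changes running through the rest of this diff, are all one migration: Literal moves from std::unique_ptr handles to value semantics. Condensed:

  Literal a = LiteralUtil::CreateR0<float>(1.0f);  // was std::unique_ptr<Literal>
  Literal b = a.Clone();                           // was a->CloneToUnique()
  CHECK_EQ(a, b);                                  // was CHECK_EQ(*a, *b)
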
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 8f0dedfa40..c4f43cc9a6 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -21,14 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index 4550f36fdf..780539c164 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30317..f8120a5fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+ TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
- LiteralUtil::CreateR0<int32>(5).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR4FromArray4D<float>(array),
+ LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 40183de96e..9a61f8ac5a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -26,9 +26,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
class WhileTransformerTest : public HloTestBase {
protected:
WhileTransformerTest()
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 00a25db467..957c4a6891 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewModule();
@@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- HloSchedule schedule(module.get());
+ HloSchedule schedule(module);
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_data, cond_lt});
schedule.set_sequence(body_computation, {body_param});
@@ -233,7 +233,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 93ec2c9438..b19ec12638 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -309,6 +309,13 @@ message HeapSimulatorTrace {
bool whole_module_simulation = 2;
}
+// An abstraction representing a set of HLO modules built to run concurrently
+// across different devices.
+message HloModuleGroupProto {
+ string name = 1;
+ repeated HloModuleProto hlo_modules = 2;
+}
+
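A minimal sketch of populating the new message from C++ via the generated proto API; the module protos here are assumed to exist:

  HloModuleGroupProto group;
  group.set_name("two_device_group");
  *group.add_hlo_modules() = module_a_proto;  // one HloModuleProto per module
  *group.add_hlo_modules() = module_b_proto;  // meant to run concurrently
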
// Serialization of BufferAssignment.
message BufferAssignmentProto {
// Alias represents a source LogicalBuffer, and the buffer location that
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 233d2199d1..8c6903d766 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ auto computation = absl::WrapUnique(
+ new HloComputation(proto.name(), parameter_count, &instructions, root,
+ /*fusion_instruction=*/nullptr));
+ computation->unique_id_ = proto.id();
+ return std::move(computation);
}
void HloComputation::FuseInstructionsInto(
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 8a45939c61..f837816cea 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
- std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
- if (result == nullptr) {
+ if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
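
The evaluator API change used here swaps a returned std::unique_ptr<Literal> (null on failure) for a bool plus an out-parameter; the calling pattern is:

  Literal result;
  if (evaluator->TryEvaluate(instruction, &result)) {
    // result holds the folded value
  } else {
    // evaluation unsupported or failed; leave the instruction alone
  }
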
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1efc12..3e0def5d26 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
@@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
+ ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);
- EXPECT_EQ(6, module->entry_computation()
+ EXPECT_EQ(6, module()
+ .entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
- HloInstruction* add = module->computations().begin()->root_instruction();
+ ParseAndVerifyModule(kConstantFoldReduce);
+ HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);
- EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index a3fcc0fefa..b76c50bb5b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(absl::make_unique<Literal>(
- LiteralUtil::Zero(operand->shape().element_type()))));
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index eb6affadc8..e07a196d11 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<int32>(
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e09d5868f2..9b18b0284f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index d0d955fea8..06b6d5b559 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -54,9 +54,8 @@ namespace xla {
namespace {
template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
@@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
@@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
@@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
- .CloneToUnique();
+ .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloModule& module, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
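The Evaluate<Literal> specialization above lets callers hand literals over by value instead of building a pointer vector themselves. A minimal caller-side sketch in the style of the tests earlier in this patch, assuming a module whose entry computation takes a single R0 parameter (the value 42 is illustrative only):

    HloEvaluator evaluator;
    TF_ASSERT_OK_AND_ASSIGN(
        Literal result,
        evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(42)}));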
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
@@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+ return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
@@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = input_literal->CloneToUnique();
+ evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal : arg_literals) {
+ arg_literal_ptrs.push_back(&literal);
+ }
+ return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction) {
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
- HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+ CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
- return nullptr;
+ return false;
}
- return result_or.ConsumeValueOrDie();
+ *result = result_or.ConsumeValueOrDie();
+ return true;
}
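Call sites of the reworked TryEvaluate check the returned bool and read the value out of a caller-owned Literal. A sketch, assuming an HloEvaluator `evaluator` and an `instruction` whose operands are all constants:

    Literal value;
    if (evaluator.TryEvaluate(instruction, &value)) {
      VLOG(1) << "TryEvaluate succeeded: " << value.ToString();
    }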
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
- HloInstruction::CreateConstant(it->second->CloneToUnique()));
+ HloInstruction::CreateConstant(it->second->Clone()));
}
}
@@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
- HloInstruction::CreateConstant(operand.CloneToUnique());
+ HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = input_literal->CloneToUnique();
+ evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
- std::unique_ptr<Literal>* reshaped_start_indices) {
+ Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
@@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
- return std::cref(**reshaped_start_indices);
+ return std::cref(*reshaped_start_indices);
}
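ReshapedGatherIndices only appends a degenerate index-vector dimension when the start indices omit it. A standalone sketch of the same reshape using Literal::Reshape directly (shape and values are illustrative, not from the patch):

    Literal start = LiteralUtil::CreateR1<int32>({0, 2, 1});           // s32[3]
    TF_ASSERT_OK_AND_ASSIGN(Literal reshaped, start.Reshape({3, 1}));  // s32[3,1]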
Status HloEvaluator::HandleGather(HloInstruction* gather) {
- std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_start_indices;
+ Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
- result->CopyElementFrom(operand, input_index, output_index));
+ result.CopyElementFrom(operand, input_index, output_index));
return true;
};
@@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand.shape().dimensions(i));
+ auto operand_dim_size = operand.shape().dimensions(i);
+ auto broadcast_dim_size =
+ broadcast->shape().dimensions(broadcast->dimensions(i));
+ TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
+ "Operand dimension %d is broadcast to output dimension %d, but the "
+ "sizes of these two dims do not match (%d vs %d): %s",
+ i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
+ broadcast->ToString());
}
TF_ASSIGN_OR_RETURN(
@@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = absl::make_unique<Literal>(
- ShapeUtil::GetTupleElementShape(operand->shape(), index));
- return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
- /*dest_shape_index=*/{},
- /*src_shape_index=*/{index});
+ evaluated_[get_tuple_element] =
+ Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
- auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
- evaluated_[copy] = std::move(result);
+ evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result;
+ Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If the predicate is of scalar type, no element-wise selection is needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
- evaluated_[select] = on_true.CloneToUnique();
+ evaluated_[select] = on_true.Clone();
} else {
- evaluated_[select] = on_false.CloneToUnique();
+ evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
- evaluated_[tuple_select] = on_true.CloneToUnique();
+ evaluated_[tuple_select] = on_true.Clone();
} else {
- evaluated_[tuple_select] = on_false.CloneToUnique();
+ evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop-carried value with the input to the While instruction.
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+ auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
@@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
- TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
- *cond_comp, {lcv.get()}));
- keep_going = cond_val->GetFirstElement<bool>();
+ TF_ASSIGN_OR_RETURN(auto cond_val,
+ cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+ keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
- *body_comp, {lcv.get()}));
- VLOG(3) << "Loop iteration result: " << body_val->ToString();
+ *body_comp, {&lcv}));
+ VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1173,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
- auto result_values_literal =
- absl::make_unique<Literal>(values_literal.shape());
- result_values_literal->PopulateR1(
+ Literal result_keys_literal(keys_literal.shape());
+ result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal result_values_literal(values_literal.shape());
+ result_values_literal.PopulateR1(
absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
- std::unique_ptr<Literal> result_tuple;
+ Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple = LiteralUtil::MakeTuple(
- {result_pair.first.get(), result_pair.second.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
- auto values_result_literal =
- absl::make_unique<Literal>(values_literal.shape());
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
+ .Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ .Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first->Reshape({1, r1_length}));
+ r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
- *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
- *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ r1_result_pair.second.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
- result_tuple = LiteralUtil::MakeTuple(
- {keys_result_literal.get(), values_result_literal.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
- VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
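For reference, the per-row key/value sort above has ordinary sort-by-key semantics. A minimal plain-C++ sketch, independent of XLA:

    #include <algorithm>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::pair<float, int>> kv = {{3.f, 30}, {1.f, 10}, {2.f, 20}};
      std::sort(kv.begin(), kv.end(),
                [](const auto& a, const auto& b) { return a.first < b.first; });
      // kv is now {{1, 10}, {2, 20}, {3, 30}}: each value travels with its key.
    }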
template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1242,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
- const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) {
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+ // For convenience the literal may have been produced with a different
+ // layout; relayout it to the layout the HLO instruction indicates.
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+ hlo->shape())) {
+ evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+ }
return Status::OK();
}
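The new Postprocess step relies on Literal::Relayout to produce a copy matching the layout carried by the HLO shape. A sketch of that call in isolation, assuming ShapeUtil::MakeShapeWithLayout to build the target shape (layouts here are illustrative only):

    Literal row_major = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
    Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
    Literal relaid = row_major.Relayout(col_major);  // same values, new layout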
// Explicit instantiation of templatized Evaluate* methods.
//
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloModule& module, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
- const Literal*>(const HloComputation& computation,
- absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- HloInstruction* instruction,
- absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 72252bafc7..21e676d671 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloModule& module,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation, given an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloComputation& computation,
- absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals);
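A sketch of the parameter-number contract described above, for a computation with two R0 parameters (names and values are illustrative only):

    Literal arg0 = LiteralUtil::CreateR0<int32>(1);  // feeds Parameter0
    Literal arg1 = LiteralUtil::CreateR0<int32>(2);  // feeds Parameter1
    TF_ASSERT_OK_AND_ASSIGN(
        Literal out,
        evaluator.Evaluate<const Literal*>(computation, {&arg0, &arg1}));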
// Evaluates a single HLO instruction, given an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as a literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
- StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction);
- // Same as Evaluate, except returning nullptr on error.
- std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+ // Same as Evaluate, except that it returns false on error and writes the
+ // result through the given output pointer.
+ bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
- StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+ StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
- HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+ StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
- HloOpcode opcode, const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+ const Literal& operand);
- StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers,
- const PrecisionConfig& precision_config, const Literal& lhs,
- const Literal& rhs);
+ StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
- return *(it->second);
+ return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-ordering.
// Must be cleared for each evaluation.
- tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
- evaluated_;
+ // Storing Literal in place requires the container to provide pointer
+ // stability, so we can no longer use FlatMap.
+ std::unordered_map<const HloInstruction*, Literal> evaluated_;
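The stability requirement matters because GetEvaluatedLiteralFor hands out references into evaluated_ that must survive later insertions. A standalone sketch of the guarantee std::unordered_map provides (plain C++, not XLA code; rehashing invalidates iterators but not references to elements):

    #include <string>
    #include <unordered_map>

    int main() {
      std::unordered_map<int, std::string> m;
      m[1] = "one";
      std::string& ref = m[1];                    // reference into the map
      for (int i = 2; i < 1000; ++i) m[i] = "x";  // forces several rehashes
      return ref == "one" ? 0 : 1;                // ref is still valid
    }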
private:
template <typename ReturnT, typename NativeT>
- static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()));
}
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 102ebb24ab..01e88566a5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
evaluator_ = absl::make_unique<HloEvaluator>();
}
- std::unique_ptr<Literal> Evaluate(
- absl::Span<const Literal* const> arg_literals = {}) {
+ Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 types to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<HloEvaluator> evaluator_;
- void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> input, float aabs = 0) {
+ void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+ float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
- b.AddInstruction(
- HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+ b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto element_type = expected->shape().element_type();
+ auto element_type = expected.shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
}
- void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> lhs,
- std::unique_ptr<Literal> rhs) {
+ void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+ Literal rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(
- HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+ HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
bool use_bfloat16_;
@@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- Shape shape = low->shape();
+ Shape shape = low.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = LiteralUtil::CreateR0<float>(1.f);
- Shape shape = value->shape();
+ Shape shape = value.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- Shape shape = on_true->shape();
+ Shape shape = on_true.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
auto c2 =
@@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+ std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
lhs_instruction, param_rhs2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate(args);
+ Literal result = Evaluate(args);
auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies Reshape operation is correctly evaluated.
@@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
@@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
- result->EachCell<NativeT>(
- [&](absl::Span<const int64> indices, NativeT value) {
- std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
- });
+ result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+ });
}
// Verifies Broadcast operation is correctly evaluated.
@@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction, {1, 2}));
+ output_literal.shape(), literal_instruction, {1, 2}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloInstruction::CreateConstant(std::move(input_literal)));
// Broadcast dimensions should be empty in the case of scalars.
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction,
+ output_literal.shape(), literal_instruction,
/*broadcast_dimensions=*/{}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int64>(
{{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<int64>({100, 200});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
- ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
- ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
PaddingConfig CreatePaddingConfig(
@@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
shape, operand_instruction, padding_value_instruction, padding_config));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
@@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
@@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected_array = Array2D<float>({
@@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = Array2D<float>({
{22.f, 28.f},
@@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest,
@@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
@@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
std::iota(input_elems.begin(), input_elems.end(), -7);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
std::iota(filter_elems.begin(), filter_elems.end(), -31);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
@@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
/*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 1, 8);
expected_array.FillWithYX(
Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
module().AddEntryComputation(b.Build());
HloEvaluator hlo_eval;
- std::unique_ptr<Literal> result =
- hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+ Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
// Reducing many numbers should be fast because it doesn't create
@@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({6, 18});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
@@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
/*strides=*/{2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
shape, operand, update, start_indices));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto result_inner_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
- auto expected = LiteralUtil::MakeTuple({
- result_inner_literal.get(),
- result_inner_literal.get(),
- });
+ auto expected =
+ LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
+ Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, &param0_literal}, {square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
- auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+ auto result =
+ evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1906,12 +1901,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1930,12 +1925,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1954,14 +1949,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>(
+ LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1980,15 +1974,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest,
@@ -2008,15 +2001,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -2035,12 +2027,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -2059,13 +2050,12 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> start_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2084,11 +2074,10 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2108,12 +2097,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> start_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), start_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2138,15 +2127,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2171,15 +2158,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2205,15 +2191,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2239,15 +2223,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2273,17 +2255,15 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(
+ LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
- ErrorSpec{0.1, 0.01}));
+ Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2309,15 +2289,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2343,15 +2321,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2376,21 +2353,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, //
{{-40, 40}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest,
@@ -2416,21 +2390,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2455,16 +2426,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2489,17 +2458,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2524,13 +2490,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *operand,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ operand, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2557,16 +2521,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<int32>({10, 61, 32});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -2603,11 +2564,29 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+ Literal arg = LiteralUtil::CreateR1<bfloat16>(
{bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+ Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+}
+
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+ // Regression test for b/114735354.
+ const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+ arg = f32[2,2,2]{0,1,2} parameter(0)
+ ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ LayoutUtil::MakeLayout({0, 1, 2}));
+ Literal actual = Evaluate({&arg});
+ EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
}
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
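Every hunk in this test file applies one mechanical migration: xla::Literal is now a move-only value type, so tests hold it directly instead of through std::unique_ptr<Literal>, dereference nothing, and pass plain pointers only where Evaluate expects a span of const Literal*. A minimal sketch of the new convention, using the same helpers the tests above exercise (an illustration, not a line of the patch):

  Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
  Literal expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
  Literal result = Evaluate({&operand});                  // returned by value
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));  // no '*' needed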
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 63303aef1e..8fb17a0033 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = absl::make_unique<Literal>(dot->shape());
+ Literal result(dot->shape());
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = absl::make_unique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
- results[i] = absl::make_unique<Literal>(result_shape);
+ results[i] = Literal(result_shape);
}
Status eval_status;
@@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
- TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
- absl::InlinedVector<std::unique_ptr<Literal>, 1>
- embedded_operands;
+ absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
- [](const std::unique_ptr<Literal>& ptr) {
- return ptr.get();
- });
+ [](Literal& literal) { return &literal; });
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
- result_values[0] = computed_result->Get<ReturnT>({});
+ result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
- result_values[i] = computed_result->Get<ReturnT>(
+ result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
- auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
- TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
// Returns a ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
- result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
@@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = absl::make_unique<Literal>(shape);
+ Literal result(shape);
TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
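The visitor-side counterpart of that migration is equally uniform: where the old code built absl::make_unique<Literal>(shape) and populated through the pointer, the new code constructs the Literal by value, populates it in place, and moves it into the evaluated map. A hedged sketch of the pattern (shape, instruction, and parent_ stand in for the members used in the hunks above; the per-element lambda is a placeholder):

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<float>([&](absl::Span<const int64> multi_index) {
        return 0.0f;  // placeholder element computation
      }));
  parent_->evaluated_[instruction] = std::move(result);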
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 0345a2a5f8..287ba84b3b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -123,6 +123,10 @@ class NodeFilter {
// We arbitrarily set this as the boundary between "large" and "small"
// instructions.
bool IsSmall(const HloInstruction* instr) {
+ if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
+ ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
+ return true;
+ }
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
}
@@ -465,9 +469,8 @@ stylesheet=<
string graph_label =
StrCat(label_, "<br/>Computation ", computation_->name());
if (computation_->IsFusionComputation()) {
- StrAppend(&graph_label,
- StrCat(" (in fusion instruction ",
- computation_->FusionInstruction()->name(), ")"));
+ StrAppend(&graph_label, " (in fusion instruction ",
+ computation_->FusionInstruction()->name(), ")");
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 25ae344ea5..e905f2983a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
- instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -527,7 +528,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
- std::unique_ptr<Literal> literal) {
+ Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
@@ -2096,7 +2097,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
- if (!control_predecessors_.empty()) {
+ if (options.print_control_dependencies() && !control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
StrJoin(control_predecessors_, ", ",
[&](string* out, HloInstruction* pre) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5581c17c2d..4f6cac1396 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -82,6 +82,7 @@ class HloPrintOptions {
print_operand_shape_(true),
print_program_shape_(true),
print_percent_(true),
+ print_control_dependencies_(true),
canonicalize_instruction_names_(false),
indent_amount_(0),
is_in_nested_computation_(false) {}
@@ -94,7 +95,8 @@ class HloPrintOptions {
.set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
- .set_print_percent(false);
+ .set_print_percent(false)
+ .set_print_control_dependencies(false);
}
// Options to produce the canonical string representing an isomorphic
@@ -108,6 +110,7 @@ class HloPrintOptions {
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
+ .set_print_control_dependencies(false)
.set_canonicalize_instruction_names(true);
}
@@ -153,6 +156,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, control dependencies will be printed.
+ HloPrintOptions& set_print_control_dependencies(bool value) {
+ print_control_dependencies_ = value;
+ return *this;
+ }
+
// If true, only part of the operands will be printed out, and their names
// will be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
@@ -190,6 +199,9 @@ class HloPrintOptions {
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
+ bool print_control_dependencies() const {
+ return print_control_dependencies_;
+ }
bool canonicalize_instruction_names() const {
return canonicalize_instruction_names_;
}
@@ -205,6 +217,7 @@ class HloPrintOptions {
bool print_operand_shape_;
bool print_program_shape_;
bool print_percent_;
+ bool print_control_dependencies_;
bool canonicalize_instruction_names_;
int indent_amount_;
bool is_in_nested_computation_;
@@ -346,8 +359,7 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
- static std::unique_ptr<HloInstruction> CreateConstant(
- std::unique_ptr<Literal> literal);
+ static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
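The new flag slots into the existing fluent-setter pattern: it defaults to true and is switched off for the short parsable and canonical forms above. A hedged usage sketch (instr is an assumed HloInstruction*; ToString taking HloPrintOptions follows the class's existing API):

  HloPrintOptions options;
  options.set_print_control_dependencies(false);
  string text = instr->ToString(options);  // omits control-predecessors={...}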
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index fb7345a2ad..e92882c22a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
- : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+ : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- if (literal_ != nullptr) {
+ if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
+ *literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
+ CHECK(literal_.has_value());
+ return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if (literal_ != nullptr &&
+ if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ *proto.mutable_literal() = literal_.ToProto();
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index c3a7801164..2d7bc83855 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
- explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is a literal associated with this instruction.
- bool HasLiteral() const { return literal_ != nullptr; }
+ bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
- string TracingTag() const { return literal_->GetR1U8AsString(); }
+ string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
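Switching the member from std::unique_ptr<Literal> to absl::optional<Literal> keeps the literal-may-be-dropped state while removing the heap indirection; the existing literal_->... call sites still compile because absl::optional also overloads operator->. A hedged sketch of the guarded access pattern the class exposes (constant is an assumed HloConstantInstruction*):

  if (constant->HasLiteral()) {                  // literal_.has_value()
    const Literal& value = constant->literal();  // *literal_
    VLOG(1) << value.ToString();
  }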
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 9bfb0af96c..c7ec88d450 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <map>
#include <queue>
@@ -582,4 +582,22 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
size_function, nullptr, empty_map);
}
+HloMemoryScheduler::HloMemoryScheduler(
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm)
+ : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, size_function_, algorithm_));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ return true;
+}
+
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+ bool changed = module->has_schedule();
+ module->clear_schedule();
+ return changed;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 54e32340ba..5e02868eba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -86,6 +87,37 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+ // size_function is the function returning the number of bytes required for a
+ // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+ // specified, then DefaultMemoryScheduler is used.
+ HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
+ ~HloMemoryScheduler() override = default;
+ absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ LogicalBuffer::SizeFunction size_function_;
+ MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs, HloModule::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+ HloDescheduler() = default;
+ ~HloDescheduler() override = default;
+ absl::string_view name() const override { return "hlo-descheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
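A short usage sketch of the two passes declared above; the size function mirrors the one used in the tests below, and the error-handling macros are the usual XLA ones.

// Attach a schedule to an HloModule* named `module`.
HloMemoryScheduler scheduler([](const BufferValue& buffer) {
  return ShapeUtil::ByteSizeOf(buffer.shape());
});
TF_ASSIGN_OR_RETURN(bool changed, scheduler.Run(module));
// module->has_schedule() now returns true.

// Later, a pass that invalidates the schedule can simply clear it.
HloDescheduler descheduler;
TF_RETURN_IF_ERROR(descheduler.Run(module).status());
// module->has_schedule() now returns false.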
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 6afe51997e..1b9e9bfc77 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <memory>
#include <string>
@@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- HloSchedule schedule,
- ScheduleModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
+ HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ });
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
- schedule.sequence(module->entry_computation()).instructions();
+ module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
EXPECT_EQ(param, sequence.front());
EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(schedule);
+ SequentialHloOrdering ordering(module->schedule());
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+ // Clear the schedule using the descheduling pass.
+ HloDescheduler descheduler;
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+ descheduler.Run(module.get()));
+ EXPECT_TRUE(descheduler_changed);
+ EXPECT_FALSE(module->has_schedule());
}
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+    // Set the unique id of this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+    // for computations and instructions created later. Also, set
+    // next_unique_id_ to one greater than the max unique id of any
+    // instruction (or of the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+  // Because we didn't uniquify the names or the ids, double-check that the
+  // instruction and computation names and ids loaded from the proto are unique.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}
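A small worked example of the ID invariant maintained above (values are illustrative):

// Suppose a deserialized computation carries instruction ids {0, 7, 42} and
// the computation itself has id 42. With uniquify_identifiers=false:
//   next_unique_id_ = max(0 + 1, 7 + 1, 42 + 1, 42 + 1) = 43
// so a later NewUniqueInstructionId() returns 43 and cannot collide with any
// id loaded from the proto.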
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 26fd1b2438..3bc2d13781 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -253,7 +253,7 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names);
+ bool uniquify_identifiers);
const string name_;
HloModuleConfig config_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index 98d20315e3..f7be5cae22 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -36,23 +36,6 @@ namespace xla {
namespace {
-bool HasSendRecv(HloComputation* computation) {
- for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kSendDone ||
- instruction->opcode() == HloOpcode::kRecv ||
- instruction->opcode() == HloOpcode::kRecvDone) {
- return true;
- }
- for (auto* sub_computation : instruction->called_computations()) {
- if (HasSendRecv(sub_computation)) {
- return true;
- }
- }
- }
- return false;
-}
-
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
for (auto* computation : module->computations()) {
@@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
while_body_root->opcode() != HloOpcode::kTuple ||
- HasSendRecv(while_body_comp)) {
+ while_body_comp->HasSideEffect() ||
+ xla_while->while_condition()->HasSideEffect()) {
      // Only run DCE on tuple-shaped while loops whose body root is a Tuple,
- // with no send/recv instructions.
+ // with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
continue;
}
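The net effect is that the hand-rolled send/recv scan is subsumed by the generic side-effect query; condensed, the new guard reads as follows (using the locals from the surrounding code):

// Skip while loops that are not tuple-shaped, whose body root is not a
// kTuple, or whose body or condition computation has any side effect
// (e.g. infeed/outfeed or send/recv), since DCE could drop observable work.
const bool skip = !ShapeUtil::IsTuple(xla_while->shape()) ||
                  while_body_root->opcode() != HloOpcode::kTuple ||
                  while_body_comp->HasSideEffect() ||
                  xla_while->while_condition()->HasSideEffect();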
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e490..bf66cc6bc3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
+// Tests that even when a loop variable is not referenced outside of a kWhile,
+// its updates are not elided within the loop body if the condition
+// computation uses them.
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
+ auto module = ParseHloString(R"(
+ HloModule InfiniteLoop
+ WhileBody {
+ body_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
+ }
+ WhileCondition {
+ cond_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ p0 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
+ while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
new file mode 100644
index 0000000000..f9b56ef464
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+namespace xla {
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ std::unique_ptr<HloModule> module)
+ : name_(name) {
+ push_back(std::move(module));
+}
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules)
+ : name_(name) {
+ for (auto& module : modules) {
+ push_back(std::move(module));
+ }
+}
+
+std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() {
+ std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_);
+
+  // Clear everything so the object is left in a known (empty) state.
+ modules_.clear();
+ module_ptrs_.clear();
+ return ret_modules;
+}
+
+string HloModuleGroup::ToString() const {
+ std::ostringstream s;
+ s << "HloModuleGroup " << name() << "\n\n";
+ for (const HloModule* module : modules()) {
+ s << module->ToString() << "\n";
+ }
+ return s.str();
+}
+
+HloModuleGroupProto HloModuleGroup::ToProto() const {
+ HloModuleGroupProto proto;
+ proto.set_name(name());
+ for (const HloModule* module : modules()) {
+ *proto.add_hlo_modules() = module->ToProto();
+ }
+ return proto;
+}
+
+/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs) {
+ TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty";
+ TF_RET_CHECK(proto.hlo_modules_size() > 0)
+ << "Module group must have at least one HLO module";
+ TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size());
+
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (int i = 0; i < proto.hlo_modules_size(); ++i) {
+ const HloModuleProto& module_proto = proto.hlo_modules(i);
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(module_proto, module_configs[i]));
+ modules.push_back(std::move(module));
+ }
+
+ return HloModuleGroup(proto.name(), absl::MakeSpan(modules));
+}
+
+void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
+ modules_.push_back(std::move(module));
+ module_ptrs_.push_back(modules_.back().get());
+}
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) {
+ out << group.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
new file mode 100644
index 0000000000..7338be8b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// An abstraction representing an ordered set of HLO modules built to run
+// concurrently across different devices.
+class HloModuleGroup {
+ public:
+ // Construct an empty module group.
+ explicit HloModuleGroup(absl::string_view name) : name_(name) {}
+
+ // Construct a module group containing a single module.
+ HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+
+ // Construct a module group containing any number of modules.
+ HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules);
+
+ // Returns the modules contained in the group.
+ const std::vector<HloModule*>& modules() const { return module_ptrs_; }
+
+ // Returns a module at a particular index.
+ HloModule& module(int index) const { return *module_ptrs_.at(index); }
+
+  // Add a module to the back of the vector of modules in the group.
+ void push_back(std::unique_ptr<HloModule> module);
+
+ // Moves all modules from the group into the returned vector. After this
+ // method runs, the module group will be empty.
+ std::vector<std::unique_ptr<HloModule>> ConsumeModules();
+
+ string name() const { return name_; }
+ string ToString() const;
+
+ // Serialize the module group to/from a proto.
+ HloModuleGroupProto ToProto() const;
+ static StatusOr<HloModuleGroup> CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs);
+
+ private:
+ string name_;
+
+ // Vector of modules as std::unique_ptrs.
+ std::vector<std::unique_ptr<HloModule>> modules_;
+
+ // Vector of modules as normal pointers. This vector is kept in sync with
+ // modules_ as modules are added to the group with push_back.
+ std::vector<HloModule*> module_ptrs_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
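A brief usage sketch of the group API declared above (module_a and module_b stand in for std::unique_ptr<HloModule> values obtained elsewhere, e.g. from ParseHloString):

HloModuleGroup group("example_group");
group.push_back(std::move(module_a));
group.push_back(std::move(module_b));

// Round-trip through the proto form; pass one HloModuleConfig per module.
TF_ASSIGN_OR_RETURN(
    HloModuleGroup copy,
    HloModuleGroup::CreateFromProto(
        group.ToProto(),
        {group.module(0).config(), group.module(1).config()}));

// Take ownership of the modules back; the group is empty afterwards.
std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();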
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
new file mode 100644
index 0000000000..ebf790ba6f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class HloModuleGroupTest : public HloTestBase {
+ protected:
+ HloModuleGroupTest() = default;
+};
+
+TEST_F(HloModuleGroupTest, SingleModule) {
+ const string text = R"(
+HloModule simple_module
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ HloModuleGroup group(TestName(), std::move(module));
+
+ EXPECT_EQ(group.modules().size(), 1);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 1);
+ EXPECT_THAT(
+ group_copy.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
+ EXPECT_EQ(modules.size(), 1);
+ EXPECT_EQ(group.modules().size(), 0);
+}
+
+TEST_F(HloModuleGroupTest, MultipleModules) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ std::vector<std::unique_ptr<HloModule>> modules;
+ modules.push_back(std::move(module_0));
+ modules.push_back(std::move(module_1));
+ HloModuleGroup group(TestName(), absl::MakeSpan(modules));
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config(),
+ group.module(1).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 2);
+}
+
+TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ HloModuleGroup group(TestName());
+ group.push_back(std::move(module_0));
+ group.push_back(std::move(module_1));
+
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 400bd4d947..39f38b417a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+  // Verify that serializing then deserializing an HLO proto preserves the
+  // unique IDs of the instructions and computations.
+ const string text =
+ R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+ input = f32[8,16,256]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+
+ // Perform various transformations on the graph:
+ //
+  // * clone the reduction function,
+  // * replace the use of the reduction function with the clone, and
+  // * add a new instruction to the entry computation.
+ //
+ // This will create instruction and computation IDs which are interesting:
+ // not consecutive and not densely packed.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+ HloComputation* reduction = root->to_apply();
+ HloComputation* reduction_clone =
+ module->AddEmbeddedComputation(reduction->Clone());
+ root->set_to_apply(reduction_clone);
+ TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+ HloInstruction* negate = entry->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+ entry->set_root_instruction(negate);
+
+  // Schedule the transformed module; this verifies that the serialized
+  // schedule is robust against non-consecutive IDs as well (b/114712358).
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ HloMemoryScheduler scheduler(size_fn);
+ TF_ASSERT_OK(scheduler.Run(module.get()).status());
+ ASSERT_TRUE(module->has_schedule());
+
+  // Serialize and deserialize, and verify that the instruction and
+  // computation unique ids are the same.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+ // The module IDs should *not* be the same because module ids must be globally
+ // unique.
+ EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+ // Verify that the computations and instructions all have the same unique id.
+ auto computation_copy_it = module_copy->computations().begin();
+ for (const HloComputation* computation_orig : module->computations()) {
+ const HloComputation* computation_copy = *computation_copy_it++;
+ EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original computation %s != ID of deserialized "
+ "computation %s: %d != %d",
+ computation_orig->name(), computation_copy->name(),
+ computation_orig->unique_id(), computation_copy->unique_id());
+
+ auto instruction_copy_it = computation_copy->instructions().begin();
+ for (const HloInstruction* instruction_orig :
+ computation_orig->instructions()) {
+ const HloInstruction* instruction_copy = *instruction_copy_it++;
+ EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original instruction %s != ID of deserialized "
+ "instruction %s: %d != %d",
+ instruction_orig->name(), instruction_copy->name(),
+ instruction_orig->unique_id(), instruction_copy->unique_id());
+ }
+ }
+
+ // Verify that the next unique ID which the module would have handed out is
+ // greater than the unique id of any instruction.
+ int next_id = module_copy->NewUniqueInstructionId();
+ for (const HloComputation* computation : module_copy->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_GT(next_id, instruction->unique_id());
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 6b6005e7a5..00970bcda3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c54360b063..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = absl::make_unique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}
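The parser-side change is purely mechanical: Literal is now move-assignable, so the out-parameter drops a level of indirection. A minimal sketch of the before/after pattern (FillLiteral is a hypothetical stand-in for the Parse*Literal helpers):

// Before: bool Parse(std::unique_ptr<Literal>* literal, const Shape& shape)
//         *literal = absl::make_unique<Literal>(shape);
// After:
bool FillLiteral(Literal* literal, const Shape& shape) {
  *literal = Literal(shape);  // move-assign a fresh literal in place
  // ... populate *literal via SetValueInLiteral-style helpers ...
  return true;
}

// Call site: a stack value instead of a heap allocation.
Literal literal;
if (!FillLiteral(&literal, shape)) {
  return false;
}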
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c95972b..d9848cee0b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0a0a6a323e..bd6dd79b67 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -27,15 +27,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -1194,51 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(HloModule* module,
- HloSchedule* schedule,
- int64 memory_limit_bytes,
- RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The schedule is constructed entirely by this method.
- TF_RET_CHECK(schedule->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
VLOG(1) << "HloRematerialization() with memory limit of "
- << HumanReadableNumBytes(memory_limit_bytes);
+ << HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial schedule of HLO instructions.
- TF_ASSIGN_OR_RETURN(*schedule,
- ScheduleModule(*module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(*schedule);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
-
- // RemoveUnnecessaryCopies only considers interference when determining
- // whether it is legal to remove a copy. However, copies in the graph may be
- // necessary for other reason such as preventing a constant from being live
- // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
- // TODO(b/80249101): Break copy insertion into several passes and run each
- // one once in the regular HLO pipeline.
- TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
-
- // The passes above can add and remove copies, update the schedule to
- // account for these transformations. Newly added instructions will be
- // placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(schedule->Update());
-
- TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(*schedule), module));
- }
-
+ TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1254,7 +1214,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
});
const int64 adjusted_memory_limit_bytes =
- memory_limit_bytes - module_output_size;
+ memory_limit_bytes_ - module_output_size;
VLOG(1) << "Adjusted memory limit accounting for output ("
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
@@ -1263,13 +1223,14 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, schedule](const CallGraphNode& node) -> Status {
+ [this, module](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(
- node.computation(),
- schedule->sequence(node.computation()).instructions()));
+ ComputePeakMemory(node.computation(),
+ module->schedule()
+ .sequence(node.computation())
+ .instructions()));
}
return Status::OK();
},
@@ -1287,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), schedule,
- adjusted_memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(
+ bool changed,
+ RematerializeComputation(module->entry_computation(), &module->schedule(),
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1298,7 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// After DCE, the module sequence may include instructions which no longer
// exist.
- TF_RETURN_IF_ERROR(schedule->Update());
+ TF_RETURN_IF_ERROR(module->schedule().Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1315,32 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
- if (sizes != nullptr) {
- sizes->before_bytes = before_peak_memory;
- sizes->after_bytes = current_peak_memory;
+ if (sizes_ != nullptr) {
+ sizes_->before_bytes = before_peak_memory;
+ sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
- if (current_peak_memory > memory_limit_bytes) {
+ if (current_peak_memory > memory_limit_bytes_) {
LOG(WARNING) << absl::StrFormat(
"Can't reduce memory use below %s (%d bytes) by rematerialization; "
"only reduced to %s (%d bytes)",
- HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+ HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
- const HloRematerialization::ShapeSizeFunction& size_function,
- int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
- copy_insertion);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index fa0414b472..e2aaf18b3e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,17 +17,23 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
namespace xla {
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -38,10 +44,7 @@ class HloRematerialization {
int64 after_bytes;
};
- // Rematerialize HLO instructions in the given module to reduce peak memory
- // use below memory_limit_bytes where memory use is defined as the total size
- // of all live HLO instruction values. Parameters and constants are included
- // in memory use estimates. Method parameters:
+ // Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -49,51 +52,27 @@ class HloRematerialization {
// memory_limit_bytes: The threshold number of bytes to reduce memory use to
// via rematerialization.
//
- // hlo_module: HLO module to rematerialize instructions in.
- //
- // schedule: Should point to an empty HloSchedule. Upon return
- // contains the HLO instruction order which was used for
- // rematerialization. This is the order in which HLO instructions should
- // be emitted to minimize memory use.
- //
- // sizes: Optional outparam that indicates the peak memory usage of the HLO
- // module before/after rematerialization.
- //
- // copy_insertion: If non-null, run copy elision after scheduling. This
- // pass is used to eliminate copies that were inserted by copy insertion
- // before HLO scheduling.
- //
- // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
- // insertion is integrated with HLO scheduling.
- //
- // Returns whether any instructions were rematerialized. If memory use is
- // already below the given limit then no instructions are rematerialized and
- // false is returned.
- //
- // CSE will undo the effects of this optimization and should not be run after
- // this pass. In general, this pass should be run very late immediately before
- // code generation.
- static StatusOr<bool> RematerializeAndSchedule(
- const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- HloSchedule* schedule, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion = nullptr);
-
- protected:
- HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function)
- : scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function) {}
+  // sizes: Pointer to a data structure which records the peak memory usage of
+  // the HLO module before/after rematerialization. Values are set during
+  // Run(). May be nullptr.
+ HloRematerialization(const ShapeSizeFunction& size_function,
+ int64 memory_limit_bytes, RematerializationSizes* sizes)
+ : size_function_(size_function),
+ memory_limit_bytes_(memory_limit_bytes),
+ sizes_(sizes) {}
~HloRematerialization() {}
+ absl::string_view name() const override { return "rematerialization"; }
+
// Runs rematerialization on the given module. Returns whether the module was
- // changed. memory_limit is the target maximum peak memory usage by the
- // module. schedule should be an empty HloSchedule. Upon return sequence
- // contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
- int64 memory_limit, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion);
+  // changed, i.e. whether any instructions were rematerialized. Requires that
+  // the module has a schedule set (HloModule::has_schedule() is true) before
+  // running. If memory use is already below the limit specified in the
+  // constructor, no instructions are rematerialized and false is returned.
+ StatusOr<bool> Run(HloModule* module) override;
+ protected:
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
@@ -121,6 +100,14 @@ class HloRematerialization {
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
+ // The threshold number of bytes to reduce memory use to via
+ // rematerialization.
+ const int64 memory_limit_bytes_;
+
+  // Pointer to a data structure which records the peak memory usage of the
+  // HLO module before/after rematerialization.
+ RematerializationSizes* sizes_;
+
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
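Putting the two passes together, a typical (illustrative) late-pipeline use mirrors the updated test harness below: schedule first, then rematerialize against a byte budget. The 16MB limit is an arbitrary example value.

auto size_fn = [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); };

// HloRematerialization::Run requires module->has_schedule() to be true.
HloMemoryScheduler scheduler(
    [&](const BufferValue& buffer) { return size_fn(buffer.shape()); });
TF_RETURN_IF_ERROR(scheduler.Run(module).status());

HloRematerialization::RematerializationSizes sizes;
HloRematerialization remat(size_fn, /*memory_limit_bytes=*/16 << 20, &sizes);
TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
// sizes.before_bytes and sizes.after_bytes now hold peak memory estimates.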
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83cb113bfb..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
- HloModule* module,
- HloSchedule* schedule) {
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- schedule, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
  // The rematerialized broadcast should be immediately before the concat in the
// sequence.
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module));
  // No instructions should have been rematerialized.
EXPECT_FALSE(changed);
@@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while, so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized, then peak memory usage is
// ~12K, so pick something slightly larger.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &schedule));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
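The helper above captures the shape of the new API: scheduling and rematerialization are now two independent passes, with the schedule carried on the HloModule itself rather than threaded through as an HloSchedule* out-parameter. A minimal sketch of the flow, reusing the tests' ByteSizeOf callback and an illustrative 14KB limit:

  HloMemoryScheduler scheduler(
      [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
      DefaultMemoryScheduler);
  TF_EXPECT_OK(scheduler.Run(module).status());  // attaches module->schedule()
  HloRematerialization remat(ByteSizeOf, /*memory_limit_bytes=*/14 * 1024,
                             /*sizes=*/nullptr);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, remat.Run(module));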
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 66ac1f66fd..fa7f216321 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals) {
+ const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
- literal_pointers.push_back(literal.get());
+ literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
- argument_pointers.push_back(argument.get());
+ argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = absl::make_unique<Literal>();
+ Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, options.outfeed_shape, literal.get()));
+ executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
- std::vector<std::unique_ptr<Literal>> exec_results;
+ std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
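Call sites migrating to the new overloads can keep Literals by value: a std::vector<Literal> converts implicitly to absl::Span<const Literal>, and the runner borrows plain pointers internally as shown above. A hedged sketch (the argument value is illustrative):

  std::vector<Literal> args;
  args.push_back(LiteralUtil::CreateR0<float>(42.0f));
  absl::Span<const Literal> span = args;  // no copies; pointers are borrowed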
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 76d8b92bed..2e934bf66a 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
- std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+ std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules come from after the HLO pass pipeline, so triggering
@@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const absl::Span<const std::unique_ptr<Literal>> literals);
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- const ShapedBuffer& buffer);
+ const absl::Span<const Literal> literals);
+ StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with the given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without HLO
// optimization.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const Literal* const> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal* const> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const absl::Span<const std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@@ -140,7 +139,7 @@ class HloRunner {
// Executes a given HLO module on a set of replicas, and returns a vector of
// result literals, one per replica, indexed by replica number.
- StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
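A typical call through the new value-based signature, assuming an HloRunner named runner and an already-built module (both stand-ins for whatever the test provides):

  std::vector<Literal> args;
  args.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
  TF_ASSIGN_OR_RETURN(Literal result,
                      runner.Execute(std::move(module), args,
                                     /*run_hlo_passes=*/true));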
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
index eb52582bb5..1424569ac1 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a1f2..6fd734a2b9 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 069586a738..50f39cbcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1123,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+ // If the module has a schedule, it must be valid.
+ if (module->has_schedule()) {
+ TF_RETURN_IF_ERROR(module->schedule().Verify());
+ }
+
return false;
}
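The practical consequence for tests: a pipeline that mutates a scheduled module must keep the schedule in sync, or verification now fails. A hedged sketch of the check as seen from a test fixture:

  // With a schedule attached, this now also runs module->schedule().Verify(),
  // so a schedule that no longer matches the instructions is an error.
  TF_EXPECT_OK(verifier().Run(module).status());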
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0cac210c24..8f0423bb1c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));
auto module = CreateNewModule();
@@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
- builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(F32).CloneToUnique())),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));
auto module = CreateNewModule();
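These two hunks show both post-migration spellings of constant creation: move the freshly built Literal in directly, or Clone() it when the value is reused. A hedged side-by-side:

  // Move the temporary in directly...
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  // ...or clone when the original Literal must stay alive for later use.
  Literal zero = LiteralUtil::Zero(F32);
  builder.AddInstruction(HloInstruction::CreateConstant(zero.Clone()));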
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 37b774b8a5..06f0e1ed25 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> inner_broadcast_result,
+ Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 9746d176cc..df9cbab915 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
- Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
- StatusOr<Literal*> TakeOwnership(
- StatusOr<std::unique_ptr<Literal>> literal_or_error) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- std::move(literal_or_error));
+ StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
- std::vector<std::unique_ptr<Literal>> owned_literals_;
+ std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
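One design note on the value-based pool (an editorial observation, not part of the change): growing a std::vector moves its elements, so a Literal* returned by an earlier TakeOwnership call can dangle after a later push_back; the pattern is safe only while no old pointer is dereferenced after the pool grows. A hypothetical illustration, with the pool name invented for the sketch:

  std::vector<Literal> pool;
  pool.push_back(LiteralUtil::Zero(F32));
  Literal* first = &pool.back();
  pool.push_back(LiteralUtil::One(F32));  // may reallocate: `first` can dangle

A std::deque<Literal> would keep previously returned pointers stable across growth.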
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 5695bc2420..7e967f035c 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloTestBase;
+using InlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
@@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
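All three tests above follow the same HloVerifiedTestBase shape: the fixture owns and verifies the module, passes take the raw pointer, and execution consumes a Clone() so the fixture's copy survives. A hedged distillation:

  Inliner inliner;
  EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());          // raw pointer
  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});  // keep original
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));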
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907eae0c..3fdc2cee9a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible(
return do_not_duplicate;
}
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create a vector of instructions in post order and create a map from
+// HloInstruction* to the instruction's index in the vector. An instruction is
+// "removed" from the vector by setting its element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+ explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+ post_order_ = computation->MakeInstructionPostOrder();
+
+ for (size_t i = 0; i < post_order_.size(); ++i) {
+ InsertOrDie(&post_order_index_, post_order_[i], i);
+ }
+ }
+
+ std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ // Instructions are "removed" from the post order by nulling out the element
+ // in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ while (!post_order_.empty() && post_order_.back() == nullptr) {
+ post_order_.pop_back();
+ }
+ if (post_order_.empty()) {
+ return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+ }
+ // We want to iterate in reverse post order, so remove from the back of the
+ // vector.
+ HloInstruction* instruction = post_order_.back();
+ post_order_.pop_back();
+
+ CHECK(instruction != nullptr);
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index_.erase(instruction);
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+ // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+  // If we are considering the operands of C for fusion into C, we might
+  // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+  //             A" = ...
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the order
+ // they appear int the queue. In the example, this ensures that B will be
+ // considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the queue.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher priority in the queue come first.
+ return (
+ FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+ });
+ return std::make_pair(instruction, sorted_operand_numbers);
+ }
+
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) override {
+ // Fusing an instruction into a fusion instruction can change the operand
+ // set of the fusion instruction. For simplicity just re-enqueue the
+ // instruction and reconsider it for further fusion in the next iteration.
+ InsertOrDie(&post_order_index_, fusion, post_order_.size());
+ post_order_.push_back(fusion);
+ }
+
+ void RemoveInstruction(HloInstruction* instruction) override {
+ post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+ post_order_index_.erase(instruction);
+ }
+
+ private:
+ std::vector<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+} // namespace
+
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer) {
+ return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(2) << "Before instruction fusion:";
XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
computation_ = computation;
reachability_ = computation_->ComputeReachability();
- // We want to be able to remove arbitrary instructions from the post order
- // and also compare positions of instructions in the post order. To make
- // this possible, create vector of instructions in post order and create a
- // map from HloInstruction* to the instruction's index in the vector. An
- // instruction is "removed" from the vector by setting it's element to
- // nullptr.
- std::vector<HloInstruction*> post_order =
- computation_->MakeInstructionPostOrder();
-
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
- for (size_t i = 0; i < post_order.size(); ++i) {
- InsertOrDie(&post_order_index, post_order[i], i);
- }
-
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+ HloInstructionSet do_not_duplicate =
+ ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
+ auto fusion_queue =
+ GetFusionQueue(computation_, [&](HloInstruction* producer) {
+ return do_not_duplicate.count(producer) > 0;
+ });
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
// edges. When we fuse an edge, we create a copy of the producer inside the
// fusion instruction.
- while (!post_order.empty()) {
- // We want to iterate in reverse post order, so remove from the back of
- // the vector.
- HloInstruction* instruction = post_order.back();
- post_order.pop_back();
-
- // Instructions are "removed" from the post order by nulling out the
- // element in the vector, so if the pointer is null, continue to the next
- // instruction in the sort.
+ while (true) {
+ auto next_entry =
+ fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+ auto instruction = next_entry.first;
if (instruction == nullptr) {
- continue;
+ break;
}
- // Remove instruction from the index map to ensure the vector and map stay
- // consistent.
- post_order_index.erase(instruction);
-
if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- // Consider each operand of this instruction for fusion into this
- // instruction. We want to consider the operands in a particular order to
- // avoid creating duplicate instruction clones in the fusion instruction.
- // For example, consider the following expression:
- //
- // A = ...
- // B = op(A)
- // C = op(A, B)
- //
- // If we are considering the operands of C for fusion into C. We might
- // fuse A or B first. If we fuse A first, we get:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // C' = op(A', B) }
- //
- // Where A' and C' are clones of A and C, respectively. Now only B is an
- // operand of the fusion instruction C_fusion, so then we fuse B:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // B' = op(A)
- // C' = op(A', B') }
- //
- // Now A is an operand of C_fusion again, so we then fuse A (again!):
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // A" = ..
- // B' = op(A")
- // C' = op(A', B') }
- //
- // We prevent this duplication by considering the operands in the reverse
- // order they appear in the instruction post order. In the example, this
- // ensures that B will be considered before A.
- //
- // We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers;
- sorted_operand_numbers.reserve(instruction->operands().size());
- for (int i = 0; i < instruction->operands().size(); ++i) {
- // This will happen if we have two possible instructions to fuse the
- // same operand into; once the operand is fused into one instruction,
- // the other instruction will get a new get-tuple-element as its
- // operand, which is not in the post-order index.
- // TODO(tjoerg): Look into fusing past these multi-output fuse points.
- if (post_order_index.find(instruction->mutable_operand(i)) ==
- post_order_index.end()) {
- continue;
- }
- sorted_operand_numbers.push_back(i);
- }
- std::sort(
- sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
- [&](int64 i, int64 j) {
- // Instructions with higher indices in the post order come
- // first.
- return (
- FindOrDie(post_order_index, instruction->mutable_operand(i)) >
- FindOrDie(post_order_index, instruction->mutable_operand(j)));
- });
+ std::vector<int64>& sorted_operand_numbers = next_entry.second;
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// TODO(tjoerg): Consider making multi-output fusion the default.
if (ShouldFuse(instruction, i) &&
do_not_duplicate.count(operand) == 0) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = Fuse(operand, instruction);
} else if (ShouldFuseIntoMultiOutput(instruction, i) &&
!MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = FuseIntoMultiOutput(operand, instruction);
} else {
continue;
}
- // Fusing an instruction into a fusion instruction can change the
- // operand set of the fusion instruction. For simplicity just push the
- // instruction to the top of the post_order and reconsider it for
- // further fusion in the next iteration of the outer loop.
- post_order.push_back(fusion_instruction);
- InsertOrDie(&post_order_index, fusion_instruction,
- post_order.size() - 1);
+ fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+ instruction);
changed = true;
if (operand->user_count() == 0) {
- // Operand is now dead. Remove from post order by setting its
- // location to nullptr.
- post_order[FindOrDie(post_order_index, operand)] = nullptr;
- post_order_index.erase(operand);
-
+ do_not_duplicate.erase(operand);
+ // Operand is now dead. Remove from queue.
+ fusion_queue->RemoveInstruction(operand);
// Remove from computation.
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
}
+
+ if (fusion_instruction != instruction) {
+ do_not_duplicate.erase(instruction);
+ }
break;
}
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 00b658959a..c1fde8ecfc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,6 +24,33 @@ limitations under the License.
namespace xla {
+// A queue interface that allows implementations to choose fusion candidates in
+// a custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+ // as operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+  // A callback invoked on the queue implementation right before the producer
+  // is fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+  // A callback invoked on the queue implementation right after the fusion is
+  // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+  // A callback invoked on the queue implementation to notify it of the
+  // removal of an instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
@@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface {
static bool IsExpensive(const HloInstruction& instruction);
protected:
+  // Returns a FusionQueue that implements a custom order of instructions being
+ // fused. The default implementation processes consumers in reverse post
+ // order.
+ virtual std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer);
+
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
// Derived classes should define this method to specify which instructions
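To see how the new extension point composes, here is a minimal sketch of a custom queue plus the override that installs it. FifoFusionQueue and MyFusionPass are illustrative names, not classes in this change, and the queue naively offers every operand index, leaving all filtering to ShouldFuse:

  #include <deque>
  #include <numeric>
  #include <set>

  // Hypothetical queue that visits instructions in reverse post order, FIFO.
  class FifoFusionQueue : public FusionQueue {
   public:
    explicit FifoFusionQueue(HloComputation* computation) {
      auto post_order = computation->MakeInstructionPostOrder();
      queue_.assign(post_order.rbegin(), post_order.rend());
    }

    std::pair<HloInstruction*, std::vector<int64>>
    DequeueNextInstructionAndOperandsToFuseInOrder() override {
      while (!queue_.empty() && removed_.count(queue_.front()) > 0) {
        queue_.pop_front();  // skip entries removed since being enqueued
      }
      if (queue_.empty()) {
        return {nullptr, {}};
      }
      HloInstruction* instruction = queue_.front();
      queue_.pop_front();
      // Offer all operands in order; ShouldFuse still decides what fuses.
      std::vector<int64> operand_numbers(instruction->operand_count());
      std::iota(operand_numbers.begin(), operand_numbers.end(), 0);
      return {instruction, std::move(operand_numbers)};
    }

    void OnFusingInstruction(HloInstruction* fusion, HloInstruction*,
                             HloInstruction*) override {
      queue_.push_front(fusion);  // reconsider the fusion node next
    }

    void RemoveInstruction(HloInstruction* instruction) override {
      removed_.insert(instruction);
    }

   private:
    std::deque<HloInstruction*> queue_;
    std::set<HloInstruction*> removed_;
  };

  // Hypothetical pass wiring the queue in through the new virtual hook.
  std::unique_ptr<FusionQueue> MyFusionPass::GetFusionQueue(
      HloComputation* computation,
      const std::function<bool(HloInstruction*)>& skip_producer) {
    return absl::make_unique<FifoFusionQueue>(computation);
  }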
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 5dea124768..a06d6113e8 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
- std::unique_ptr<Literal> result_literal;
+ Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
- TF_ASSIGN_OR_RETURN(result_literal,
- evaluator_->Evaluate<std::unique_ptr<Literal>>(
- *computation, arg_literals));
+ TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+ *computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
- result_literal->shape(), run_options->allocator(),
+ result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- run_options->stream(), *result_literal, result));
+ run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
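The evaluator keeps its template parameter, so the call shape is unchanged apart from the return type now being a plain value. A hedged restatement of the call above:

  HloEvaluator evaluator;
  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      evaluator.Evaluate<Literal>(*computation, arg_literals));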
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 69c7e42601..752a61476d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout,
@@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(module.get())
+ .Run(module)
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_4, "param"));
auto broadcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ HloInstruction::CreateBroadcast(f32_34, param, {1}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
auto broadcast2 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
auto module = CreateNewModule();
@@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
- module =
+ std::unique_ptr<HloModule> compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
EXPECT_EQ(Status::OK(), backend()
.compiler()
- ->RunBackend(std::move(module),
+ ->RunBackend(std::move(compiled_module),
backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.status());
@@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
param_shape));
computation_layout.mutable_result_layout()->ResetLayout(
LayoutUtil::MakeLayout({2, 1, 0}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(&module(), &computation_layout);
- EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(1)
.layout()
@@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
HloComputation* computation = module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
const HloInstruction* true_root = true_computation->root_instruction();
const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(module.get()).status();
+ Status error_status = layout_assignment.Run(module).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
@@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
TF_ASSERT_OK(
@@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
LayoutUtil::MakeLayout({1, 0}));
ChannelLayoutConstraints channel_constraints;
- AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+ AssignLayouts(&module(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::GetSubshape(
- FindInstruction(module.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
@@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest,
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
@@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
auto compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
HloInstruction* root =
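The compile-and-inspect tests above all share one shape once migrated; a hedged distillation, where module_str stands for whatever HLO text the test supplies:

  ParseAndVerifyModule(module_str);  // fixture parses, verifies, and owns it
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();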
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index f0e2566a3f..b27a92f2a0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
- *module->add_arguments() = literal->ToProto();
+ *module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
- *module->mutable_result() = literal->ToProto();
+ *module->mutable_result() = literal.ToProto();
return Status::OK();
}
@@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> result_literal,
+ Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
- if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
- result_literal->shape())) {
- *result->mutable_literal() = result_literal->ToProto();
+ if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+ *result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
- result_literal->Relayout(*return_shape)->ToProto();
+ result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- stream.get(), *literal, shaped_buffer));
+ stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, *literal);
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+ literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), *literal));
- *result->mutable_literal() = literal->ToProto();
+ executor, arg->shape_with_layout(), literal));
+ *result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, /*arg_literals=*/{}));
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+ *module, /*arg_literals=*/{}));
// Since the result layout has no effect on the Evaluator's results, relayout
// explicitly here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
+ result_literal = result_literal.Relayout(arg->output_layout());
}
- *result->mutable_literal() = result_literal->ToProto();
+ *result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}
@@ -1162,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
return replicas;
}
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
const string xla_dump_unoptimized_hlo_proto_to =
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1170,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const {
}
HloProto proto = MakeHloProto(module);
return protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+ proto, xla_dump_unoptimized_hlo_proto_to,
+ StrCat(module.name(), ".unoptimized"));
}
} // namespace xla
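
The rename to MaybeDumpUnoptimizedHloModule comes with an explicit ".unoptimized" suffix on the dump filename, so an unoptimized dump cannot collide with an optimized dump of the same module written to the same directory. Assuming the StrCat in scope resolves to absl::StrCat, the filename construction is simply:

#include <string>
#include "absl/strings/str_cat.h"

// Hypothetical helper mirroring the call above; module_name comes from
// HloModule::name() in the real code.
std::string UnoptimizedDumpName(const std::string& module_name) {
  return absl::StrCat(module_name, ".unoptimized");
}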
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248b15..1f62fad4c8 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@ class Service : public ServiceInterface {
StatusOr<std::vector<se::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
- Status MaybeDumpHloModule(const HloModule& module) const;
+ // Dumps the given (unoptimized) module if the corresponding DebugOptions
+ // field has been set.
+ Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
deleted file mode 100644
index dd53c7531b..0000000000
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/source_map_util.h"
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-namespace source_map_util {
-namespace {
-
-Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
- const char* format, va_list args) {
- string message;
- tensorflow::strings::Appendv(&message, format, args);
- if (!op_metadata.source_file().empty()) {
- absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
- op_metadata.source_line());
- }
- return InvalidArgument("%s", message);
-}
-
-} // namespace
-
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- Status result = InvalidParameterArgumentV(op_metadata, format, args);
- va_end(args);
- return result;
-}
-
-Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- if (executable != nullptr && executable->has_module()) {
- const HloModule& module = executable->module();
- const HloComputation& computation = *module.entry_computation();
- HloInstruction* param = computation.parameter_instruction(parameter_number);
- const OpMetadata& metadata = param->metadata();
- Status result = InvalidParameterArgumentV(metadata, format, args);
- va_end(args);
- return result;
- }
- Status result = InvalidArgumentV(format, args);
- va_end(args);
- return result;
-}
-
-} // namespace source_map_util
-} // namespace xla
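
The deleted helpers follow the classic va_list-forwarding idiom: a V-suffixed core consumes a va_list, while thin variadic wrappers own the matching va_start/va_end pair. A self-contained illustration of the same idiom built on vsnprintf:

#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Core: consumes a va_list the caller started.
static std::string FormatV(const char* format, va_list args) {
  va_list copy;
  va_copy(copy, args);  // the measuring pass consumes a list, so copy first
  const int n = vsnprintf(nullptr, 0, format, copy);
  va_end(copy);
  if (n <= 0) return std::string();
  std::vector<char> buf(static_cast<size_t>(n) + 1);
  vsnprintf(buf.data(), buf.size(), format, args);
  return std::string(buf.data(), static_cast<size_t>(n));
}

// Wrapper: owns the va_start/va_end pairing, forwards to the V core.
std::string Format(const char* format, ...) {
  va_list args;
  va_start(args, format);
  std::string result = FormatV(format, args);
  va_end(args);
  return result;
}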
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index b8d2d546e5..a21e586efa 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
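
Returning std::move(literal) in place of absl::make_unique<Literal>(std::move(literal)) works because StatusOr<T> converts implicitly from T; with a move-only value type the result is moved straight into the wrapper, with no extra heap box. A stand-in sketch using std::optional in place of StatusOr:

#include <memory>
#include <optional>
#include <utility>

// Move-only stand-in for xla::Literal.
struct LiteralLike {
  std::unique_ptr<int[]> buffer;
};

std::optional<LiteralLike> TransferFromDevice() {
  LiteralLike literal;
  literal.buffer = std::make_unique<int[]>(16);
  // ... fill buffer ...
  // Moves into the wrapper via its converting constructor; previously the
  // value had to be boxed in a std::unique_ptr first.
  return std::move(literal);
}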
@@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return absl::make_unique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(
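
The substream comments above describe a sync-over-async wrapper: the blocking variant enqueues the asynchronous transfer on a substream and waits, so a caller already executing on the main stream (for example inside a HostCallback) never ends up waiting on itself. A loose standard-C++ analogy, with a fresh thread standing in for the substream:

#include <future>
#include <string>

std::string TransferAsync() { return "device-bytes"; }  // stand-in async op

// Blocking wrapper: run the work on a separate execution context and wait,
// rather than queueing it behind the caller on the same serial stream.
std::string TransferSync() {
  std::future<std::string> f = std::async(std::launch::async, TransferAsync);
  return f.get();
}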
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 21725946b3..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
- virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source);
+ StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+ const Shape& shape,
+ const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2b2a2eb42a..e9a07b14ed 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})};
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
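
With Literal now a move-only value, the rewritten test keeps the element literals alive in a local array and hands MakeTuple non-owning pointers. The ownership pattern in isolation, with stand-in types rather than the LiteralUtil API:

#include <string>
#include <vector>

struct ElementLike { std::string payload; };

// Borrows: the caller guarantees the pointees outlive the returned view.
std::vector<const ElementLike*> MakeTupleView(
    std::vector<const ElementLike*> parts) {
  return parts;
}

void Demo() {
  ElementLike elements[] = {{"r2 {{1.0},{2.0}}"}, {"r1 {2.0,42}"}};
  auto tuple = MakeTupleView({&elements[0], &elements[1]});
  (void)tuple;  // elements[] must stay in scope while tuple is used
}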
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b693872d..516754e211 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
EXPECT_THAT(computation->root_instruction(), gte);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param1);
}
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
EXPECT_THAT(computation->root_instruction(),
op::Negate(op::GetTupleElement(op::Tuple())));
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
EXPECT_THAT(computation->root_instruction(), element);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), tuple_param);
}
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
EXPECT_THAT(computation->root_instruction(), tuple);
}
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
entry = module->AddEntryComputation(builder.Build());
}
- Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+ Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);
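
Dropping .get() in the Run calls is consistent with HloVerifiedTestBase handing out raw pointers to modules the fixture owns (and can verify at teardown), where HloTestBase returned std::unique_ptr. A hypothetical sketch of that ownership shape, not the real test fixture:

#include <memory>
#include <vector>

struct ModuleLike { /* ... */ };

class OwningFixture {
 public:
  // Returns a non-owning pointer; the fixture keeps the module alive for
  // the duration of the test.
  ModuleLike* CreateNewModule() {
    modules_.push_back(std::make_unique<ModuleLike>());
    return modules_.back().get();
  }

 private:
  std::vector<std::unique_ptr<ModuleLike>> modules_;
};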
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603c7e..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();