diff options
Diffstat (limited to 'tensorflow/compiler/xla/service')
107 files changed, 2461 insertions, 1922 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e784663ff6..fb80c78f68 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -352,6 +354,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,6 +405,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -498,6 +502,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -546,6 +551,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -568,6 +574,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1012,8 +1019,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1041,8 +1048,8 @@ tf_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1088,8 +1095,8 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1131,6 +1138,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1139,6 +1147,37 @@ tf_cc_test( ) cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], hdrs = ["hlo_module_group_metadata.h"], @@ -1185,9 +1224,9 @@ tf_cc_test( ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1199,13 +1238,14 @@ tf_cc_test( ) cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = ["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1219,15 +1259,15 @@ cc_library( ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1259,6 +1299,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1392,6 +1433,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1708,6 +1750,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1777,6 +1820,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -1953,6 +1997,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", + ":hlo_memory_scheduler", ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2236,6 +2281,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2314,6 +2360,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2394,12 +2441,11 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2428,6 +2474,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2494,6 +2541,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2611,6 +2659,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2888,6 +2937,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 3d18fe3be2..5458159d14 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr<bool> FoldConvInputPad(HloInstruction* convolution); + StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -1469,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { auto* iota = Cast<HloIotaInstruction>(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -1572,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr<HloInstruction> ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1607,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2057,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (pad_literal == reduce_init_literal) { return true; } - auto converted_pad_literal = pad_literal.ConvertToShape( - reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + auto converted_pad_literal = + pad_literal.ConvertToShape(reduce_init_value->shape()); if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. @@ -2212,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), - {})); - } - + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> { - if (lhs->opcode() != HloOpcode::kPad) { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(lhs->operand(1), 0)) { + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { return false; } - const auto& padding = lhs->padding_config(); - - // Can't pad batch or feature dims. - for (int64 dim : - {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); } + } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = window; - for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); - // Edge padding composes with itself in the straightforward way, but - // composing interior padding is nontrivial, and we cowardly refuse to - // think about it. If we see interior padding in either the kPad or conv, - // bail if there's any sort of padding in the other. - if (p.interior_padding() != 0 && - (w.padding_low() != 0 || w.padding_high() != 0 || - w.base_dilation() != 1)) { - return false; - } - if (w.base_dilation() != 1 && - (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0)) { - return false; - } + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - w.set_padding_low(w.padding_low() + p.edge_padding_low()); - w.set_padding_high(w.padding_high() + p.edge_padding_high()); - if (p.interior_padding() != 0) { - CHECK_EQ(w.base_dilation(), 1); - w.set_base_dilation(1 + p.interior_padding()); - } - } +StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs->mutable_operand(0), rhs}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } - if (folded_input_pad) { - return Status::OK(); + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; } - // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> { - if (rhs->opcode() != HloOpcode::kPad) { - return false; - } + const auto& padding = rhs->padding_config(); - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(rhs->operand(1), 0)) { + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - const auto& padding = rhs->padding_config(); + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - // Can't pad or dilate feature dims. - for (int64 dim : {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = convolution->window(); - for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - - // We can only do this transformation if p adds dilation to the filter -- - // edge padding on the filter is not supported in conv. - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { - return false; - } - - // Nothing to do if the kPad for this dim is entirely a nop. - if (p.interior_padding() == 0) { - continue; - } + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } - // We cowardly refuse to think about how dilation composes with itself; - // bail if both the kPad and conv have dilation on this dimension. - if (w.window_dilation() > 1) { - return false; - } - CHECK_EQ(w.window_dilation(), 1); - w.set_window_dilation(1 + p.interior_padding()); - w.set_size(rhs->operand(0)->shape().dimensions( - dnums.kernel_spatial_dimensions(dim))); + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs, rhs->mutable_operand(0)}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - if (folded_filter_pad) { - return Status::OK(); - } +StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } const Shape& input_shape = lhs->shape(); @@ -2388,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2399,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2425,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2467,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2479,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, convolution->precision_config())); - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. + if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a0db4563fb..3fc1ba2427 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr<Literal> value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0<float>(constant_scalar).get(), - LiteralUtil::CreateR1<float>(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar), + LiteralUtil::CreateR1<float>(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec281ae68f..30d33e0d35 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0<float>(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aba0d9bb5b..f7ac8f5482 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +using BatchNormExpanderTest = HloVerifiedTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + ParseAndVerifyModule(module_str); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); - for (auto* instruction : module->entry_computation()->instructions()) { + for (auto* instruction : module().entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 6363a21c3b..5f93740887 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 933cf873e0..cef0eba14e 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); @@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 545a6ecfb1..58f78f8e24 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN( - auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape(), - /*round_f32_to_bf16=*/true)); + TF_ASSIGN_OR_RETURN(auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 388fd5df99..e032b5c624 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f0af57626..65fa951afe 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 5a231c173d..795beb9ff5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -30,11 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0<int64>(0), + LiteralUtil::CreateR0<int64>(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(), - LiteralUtil::CreateR0<int64>(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 414bfe7999..17e5090505 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(), - LiteralUtil::CreateR0<int64>(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0<int64>(0), + LiteralUtil::CreateR0<int64>(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0<int64>(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b74843..34f3f914d5 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr<HloComputation> MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); std::vector<HloComputation*> visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 5d85a3f173..e6b5665435 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), @@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { HloComputation::Builder call_false_builder(TestName() + ".call_false"); call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); + call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = module->AddEmbeddedComputation(call_false_builder.Build()); @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index e5a6c28478..96bd2616f5 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModule> hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0826380f65..0ac4a65ec6 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>( - LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); auto new_filter = add( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2368ac8c6a..8cc522a59e 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -122,7 +122,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", @@ -801,6 +801,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -822,6 +823,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -946,6 +948,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -971,6 +974,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 05792795a1..2083f440fd 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7b6075994..18fc144efe 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,12 +77,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446e..c9fb34be1c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de6..be1208fb2d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 942e2ddd39..55d5925642 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // Transfer parameters. - std::unique_ptr<xla::Literal> param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr<xla::GlobalData> param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR2<float>( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr<xla::GlobalData> param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr<xla::XlaComputation> computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. xla::ExecutionProfile profile; - xla::StatusOr<std::unique_ptr<xla::Literal>> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie(); + xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); LOG(INFO) << absl::StrFormat("computation took %dns", profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 7d8e51f909..1a3d82de95 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include <random> #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector<int64> Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector<std::pair<int64, int64>> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector<std::pair<int64, int64>> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index f11aff0573..c55206eee7 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,6 +48,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 22721051e5..1deb412064 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_); } TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { @@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c661..5cc6d01c0f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(), - LiteralUtil::CreateR0<bool>(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<uint32>({1, 2, 3}), + LiteralUtil::CreateR0<bool>(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. - LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(), - LiteralUtil::CreateR0<bool>(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}), + LiteralUtil::CreateR0<bool>(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(), - LiteralUtil::CreateR0<bool>(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}), + LiteralUtil::CreateR0<bool>(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(), - LiteralUtil::CreateR0<bool>(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}), + LiteralUtil::CreateR0<bool>(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(), - LiteralUtil::CreateR0<bool>(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}), + LiteralUtil::CreateR0<bool>(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr<GlobalData> result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(), - LiteralUtil::CreateR0<bool>(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}), + LiteralUtil::CreateR0<bool>(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(), - LiteralUtil::CreateR0<bool>(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}), + LiteralUtil::CreateR0<bool>(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(), - LiteralUtil::CreateR0<bool>(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}), + LiteralUtil::CreateR0<bool>(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index bb105194f1..7af51db55a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr<Literal> literal = - LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 1b3be199f6..852f34e06d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -56,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}}); - std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241e..5fbd73a536 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr<HloComputation> MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 4ed91ef187..bec02e14f9 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr<Literal> relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6791e15ee0..64b9683628 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -173,6 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", @@ -370,6 +372,8 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", + ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -395,6 +399,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -813,9 +818,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], ) @@ -832,6 +837,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -901,6 +907,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 05448d863d..3a23ac1d63 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,62 +31,32 @@ namespace gpu { using se::dnn::AlgorithmDesc; -ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, - int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - feature_group_count_(feature_group_count), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} - Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + CudnnConvParams params; + + params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); + params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); + params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. void* result_ptr = [&] { - switch (convolution_kind_) { + switch (params.kind) { case CudnnConvKind::kForward: - return output_data.opaque(); + return params.output_buf.opaque(); case CudnnConvKind::kBackwardInput: - return input_data.opaque(); + return params.input_buf.opaque(); case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); + return params.filter_buf.opaque(); } }(); void* ptrs[] = {result_ptr, scratch.opaque()}; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 68d67c40c5..d7d1f91fba 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // // Note that "output" here doesn't refer to the output from running this // thunk, but rather to the "output" of a hypothetical forward convolution // that corresponds to this input+filter+output triple. That is, the result // generated by this thunk is "output" for forward convs, "input" for // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, - int64 feature_group_count, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + BufferAllocation::Slice input_slice, + BufferAllocation::Slice filter_slice, + BufferAllocation::Slice output_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + input_buffer_(std::move(input_slice)), + filter_buffer_(std::move(filter_slice)), + output_buffer_(std::move(output_slice)), + scratch_buffer_(std::move(scratch_slice)), + tuple_result_buffer_(std::move(tuple_result_slice)) {} ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 feature_group_count_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice filter_buffer_; + BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5c2555148a..f528e62b17 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/mutex.h" @@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // caching would speed up compilation a lot. StatusOr<std::tuple<int64, bool, int64>> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr) { + const HloCustomCallInstruction* instr) { + CudnnConvParams params; + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, ¶ms)); + + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); CHECK_EQ(input_shape.element_type(), output_shape.element_type()); // TODO(timshen): for now only check fp16. It can be expanded to other types, @@ -216,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( allocator = &*se_allocator; } - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. - ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); size_t left_over_bytes = buffer.size() % 4; CHECK_EQ(0, left_over_bytes % 2); @@ -252,33 +244,46 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( DeviceMemoryBase left_over( static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - }; - initialize_f16(input_buf); - initialize_f16(filter_buf); - initialize_f16(output_buf); - } else { - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the - // inputs might affect performance if e.g. the inputs contain denormals, and - // this is easy enough. - stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()); - } + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); + } + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + TF_ASSIGN_OR_RETURN(params.input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(params.filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(params.output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); + + initialize_buffer(params.input_buf); + initialize_buffer(params.filter_buf); + initialize_buffer(params.output_buf); DeviceMemoryBase* result_buf = [&] { - switch (kind) { + switch (params.kind) { case CudnnConvKind::kBackwardFilter: - return &filter_buf; + return ¶ms.filter_buf; case CudnnConvKind::kBackwardInput: - return &input_buf; + return ¶ms.input_buf; case CudnnConvKind::kForward: - return &output_buf; + return ¶ms.output_buf; } }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); + input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; @@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. optional<AlgorithmDesc> first_algorithm; for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, - filter_buf, output_buf, &scratch_allocator, window, dnums, - feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + params.algorithm = AlgorithmConfig(alg); + bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, + &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { const bool crash_on_checking_failure = @@ -374,34 +377,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = - PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr->feature_group_count(), - instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } + StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc = + PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr)); if (!alg_scratch_and_tc.ok()) { LOG(ERROR) << alg_scratch_and_tc.status(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 0cb01161b0..f79b113f8f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnInstruction(HloInstruction* instr); StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr); + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 9bf721ecd2..228379a248 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include <cstdlib> #include <numeric> #include <vector> @@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward filter. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( - HloInstruction* conv) { +std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*> +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward input. + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else if (reverse_filter->IsConstant()) { - // If the filter is a constant, we're willing to pattern-match to a - // backwards-input conv, on the theory that - // - // a) reversing a constant is free, and - // b) even if the user specified this filter as reverse(constant), we would - // long ago have constant-folded away the reverse. - // - // If the constant has any other uses, reversing it isn't entirely free, - // since we'd now have two constants to keep in memory. But hopefully it's - // free enough. - // - // TODO(jlebar): Should we do this even if the filter is not a constant? - // Reversing a non-constant filter is probably cheaper than padding the - // input! - - // Nothing to do, just fall through. - } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. @@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( } } - // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. - // This simplifies things for our caller, and algebraic-simplifier will later - // remove any unnecessary reverses. - if (reverse_filter->opcode() != HloOpcode::kReverse) { + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. + dnums.set_kernel_input_feature_dimension( + conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { // Create a double-reverse, which is a nop. HloComputation* c = conv->parent(); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } - dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); - dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); } // Tries to rewrite a single convolution into a call to cudnn. @@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { @@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. - HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput(conv->shape(), conv->mutable_operand(0), rhs, window, dnums, conv->feature_group_count()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index bda8ebe579..d237f8930b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { Array4D<float> constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); ParseAndVerifyModule(absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 05125e9d1f..2a86ac265e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template <typename T> -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory<T> input_buf, - DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory<T> input_buf(params.input_buf); + DeviceMemory<T> filter_buf(params.filter_buf); + DeviceMemory<T> output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, - output_buf, &scratch_allocator, window, dnums, feature_group_count, - algorithm, stream, profile_result); + return RunCudnnConvolution(params, &scratch_allocator, stream, + profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = params.output_shape->element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator, + stream, profile_result); case F32: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), - se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, - feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<double>(input_buf), - se::DeviceMemory<double>(filter_buf), - se::DeviceMemory<double>(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream, + profile_result); default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index a1b4fc71d0..381aa37a1b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -47,6 +47,20 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter }; +struct CudnnConvParams { + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; +}; + // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); @@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. +// At the moment convolution with half data type is implemented with cudnn +// PSEUDO_HALF configuration, that is, the input values are half and the +// internal computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index ea9376e101..02a0d028c1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 59ade96f7d..b857fa775a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector<const HloInstruction*>; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c..27a4d0b601 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 20d523abe0..22f43bc08b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params) { + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config<CudnnConvBackendConfig>()); + const auto& target = custom_call->custom_call_target(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + + params->window = &custom_call->window(); + params->dnums = &custom_call->convolution_dimension_numbers(); + params->feature_group_count = custom_call->feature_group_count(); + params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + + if (target == kCudnnConvForwardCallTarget) { + params->kind = CudnnConvKind::kForward; + params->input_shape = &lhs_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &conv_result_shape; + } else if (target == kCudnnConvBackwardInputCallTarget) { + params->kind = CudnnConvKind::kBackwardInput; + params->input_shape = &conv_result_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &lhs_shape; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params->kind = CudnnConvKind::kBackwardFilter; + params->input_shape = &lhs_shape; + params->filter_shape = &conv_result_shape; + params->output_shape = &rhs_shape; + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59c65fc268..09c455cc1e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,9 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Populates params using conv, which must be a custom-call to a cudnn +// convolution. Does not modify any buffers in the params. +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f91cc00d71..b669881026 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config<CudnnConvBackendConfig>()); const auto& target = custom_call->custom_call_target(); - std::unique_ptr<ConvolutionThunk> thunk; + BufferAllocation::Slice input_slice, filter_slice, output_slice; + if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = rhs_slice; + output_slice = conv_result_slice; } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = conv_result_slice; + filter_slice = rhs_slice; + output_slice = lhs_slice; } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = conv_result_slice; + output_slice = rhs_slice; } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); } - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>( + Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice, + output_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f6325b3368..dfdcf1875d 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass<CudnnConvolutionRewriter>(); - // CudnnConvolutionRewriter may add instructions of the form - // reverse(constant), which it expects will be simplified by constant - // folding. - pipeline.AddPass<HloConstantFolding>(); pipeline.AddPass<PadInsertion>(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<PadForTensorCores>(); @@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // pairs that TupleSimplifier fixes. pipeline.AddPass<TupleSimplifier>(); } + // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add + // instructions which can be simplified by constant folding. + pipeline.AddPass<HloConstantFolding>(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index fa84d77223..b0061fa655 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { - // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. static constexpr int64 kDesiredNumFeaturesFactor = 8; @@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9d85d746d8..2a6415d0b6 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique<Literal>( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8f0dedfa40..c4f43cc9a6 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,14 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr<HloModule> CreateNewModule() { HloModuleConfig config; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 4550f36fdf..780539c164 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr<Literal> literal = - LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317..f8120a5fa0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D<float> array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(), - LiteralUtil::CreateR0<bool>(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<uint32>({1, 2, 3}), + LiteralUtil::CreateR0<bool>(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D<float> array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D<float>(array).get(), - LiteralUtil::CreateR0<int32>(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D<float>(array), + LiteralUtil::CreateR0<int32>(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 40183de96e..9a61f8ac5a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -26,9 +26,6 @@ limitations under the License. namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 00a25db467..957c4a6891 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module.get()); + HloSchedule schedule(module); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -233,7 +233,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 93ec2c9438..b19ec12638 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -309,6 +309,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 233d2199d1..8c6903d766 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -562,9 +562,11 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 8a45939c61..f837816cea 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { continue; } - std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction); + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. - if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 07cd1efc12..3e0def5d26 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral<F32>( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { root->literal().EachCell<NativeT>( [&](absl::Span<const int64> indices, NativeT value) { std::vector<int64> rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get<NativeT>(rindexes)); + matched = matched && (value == literal_clone.Get<NativeT>(rindexes)); }); EXPECT_TRUE(matched); } @@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(kConstantFoldReduce)); + ParseAndVerifyModule(kConstantFoldReduce); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_TRUE(result); - EXPECT_EQ(6, module->entry_computation() + EXPECT_EQ(6, module() + .entry_computation() ->root_instruction() ->literal() .GetFirstElement<int32>()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(kConstantFoldReduce)); - HloInstruction* add = module->computations().begin()->root_instruction(); + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_FALSE(result); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a3fcc0fefa..b76c50bb5b 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique<Literal>( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr<HloInstruction*> BroadcastZeros( HloComputation* computation, PrimitiveType element_type, absl::Span<const int64> broadcast_dimensions) { - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index eb6affadc8..e07a196d11 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate<Literal>( *module, {LiteralUtil::CreateR1<int32>({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( + Literal result_literal, + evaluator.Evaluate<Literal>( *module, {LiteralUtil::CreateR3<int32>( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2<int32>( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2<int32>( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( - *module, {LiteralUtil::CreateR1<int32>({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}})); + Literal result_literal, + evaluator.Evaluate<Literal>(*module, + {LiteralUtil::CreateR1<int32>({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( - *module, {LiteralUtil::CreateR1<int32>({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate<Literal>(*module, + {LiteralUtil::CreateR1<int32>({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( - *module, {LiteralUtil::CreateR0<int32>(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( + Literal result_literal, + evaluator.Evaluate<Literal>( *module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate<Literal>( *module, {LiteralUtil::CreateR1<int32>({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( - *module, {LiteralUtil::CreateR0<int32>(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate<Literal>( *module, {LiteralUtil::CreateR0<float>(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index e09d5868f2..9b18b0284f 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0<float>(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index d0d955fea8..06b6d5b559 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -54,9 +54,8 @@ namespace xla { namespace { template <typename OperandT> -StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function<bool(OperandT, OperandT)> compare_op; switch (opcode) { case HloOpcode::kEq: @@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<bool>([&](absl::Span<const int64> multi_index) { + result.Populate<bool>([&](absl::Span<const int64> multi_index) { return compare_op(lhs_literal.Get<OperandT>(multi_index), rhs_literal.Get<OperandT>(multi_index)); })); @@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr<std::unique_ptr<Literal>> Compare<complex64>( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function<bool(complex64, complex64)> compare_op; switch (opcode) { case HloOpcode::kEq: @@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>( << HloOpcodeString(opcode); } - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<bool>([&](absl::Span<const int64> multi_index) { + result.Populate<bool>([&](absl::Span<const int64> multi_index) { return compare_op(lhs_literal.Get<complex64>(multi_index), rhs_literal.Get<complex64>(multi_index)); })); @@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) } template <typename LiteralPtr> -StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( +StatusOr<Literal> HloEvaluator::Evaluate( const HloModule& module, absl::Span<const LiteralPtr> arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); @@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr<Literal> HloEvaluator::Evaluate<Literal>( + const HloModule& module, absl::Span<const Literal> arg_literals) { + std::vector<const Literal*> arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate<const Literal*>(module, arg_literal_ptrs); } template <typename LiteralPtr> -StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( +StatusOr<Literal> HloEvaluator::Evaluate( const HloComputation& computation, absl::Span<const LiteralPtr> arg_literals) { CHECK(computation.parent() != nullptr); @@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr<Literal> HloEvaluator::Evaluate<Literal>( + const HloComputation& computation, absl::Span<const Literal> arg_literals) { + std::vector<const Literal*> arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate<const Literal*>(computation, arg_literal_ptrs); } template <typename LiteralPtr> -StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( +StatusOr<Literal> HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); @@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); +} + +template <> +StatusOr<Literal> HloEvaluator::Evaluate<Literal>( + HloInstruction* instruction, absl::Span<const Literal> arg_literals) { + std::vector<const Literal*> arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate<const Literal*>(instruction, arg_literal_ptrs); } -StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( - HloInstruction* instruction) { +StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr<Literal> HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions( +StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map<const HloInstruction*, const Literal*>& substitutions) { @@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr<HloInstruction> lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr<HloInstruction> rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr<HloInstruction> cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr<HloInstruction> operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr<HloInstruction> cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp( +StatusOr<Literal> HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr<HloInstruction> lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr<HloInstruction> rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex { // there is one) to `reshaped_start_indices`. static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices( int64 index_vector_dim, const Literal& start_indices, - std::unique_ptr<Literal>* reshaped_start_indices) { + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices( new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_start_indices, start_indices.Reshape(new_shape)); - return std::cref(**reshaped_start_indices); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr<Literal> reshaped_start_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), @@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + result.CopyElementFrom(operand, input_index, output_index)); return true; }; @@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( @@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = absl::make_unique<Literal>( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr<Literal> result = + Literal result = embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - std::unique_ptr<Literal> result = + Literal result = embedded_evaluator .Evaluate<const Literal*>(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr<Literal> result; + Literal result; if (pred.Get<bool>({})) { result = embedded_evaluator .Evaluate<const Literal*>(*true_computation, @@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get<bool>({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get<bool>({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. - auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); @@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement<bool>(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement<bool>(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>( - *body_comp, {lcv.get()})); - VLOG(3) << "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. namespace { template <typename KeyType, typename ValueType> -StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) @@ -1173,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape()); - result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys)); - auto result_values_literal = - absl::make_unique<Literal>(values_literal.shape()); - result_values_literal->PopulateR1( + Literal result_keys_literal(keys_literal.shape()); + result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys)); + Literal result_values_literal(values_literal.shape()); + result_values_literal.PopulateR1( absl::Span<const ValueType>(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; - std::unique_ptr<Literal> result_tuple; + Literal result_tuple; if (rank == 1) { auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); + result_tuple = + LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape()); - auto values_result_literal = - absl::make_unique<Literal>(values_literal.shape()); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); + .Reshape({r1_length})); TF_ASSIGN_OR_RETURN(auto values_r1_slice, values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + .Reshape({r1_length})); + auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); + r1_result_pair.first.Reshape({1, r1_length})); TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + r1_result_pair.second.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values, {0, 0}, {row, 0}, {1, r1_length})); } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template <typename KeyType> -StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { case F32: return EvaluateSortInternal<KeyType, float>(sort, keys_literal, @@ -1242,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried( } } -StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr<Literal> EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried<float>(sort, keys_literal, values_literal); @@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr<std::unique_ptr<Literal>> -HloEvaluator::Evaluate<const Literal*>( +template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>( const HloModule& module, absl::Span<const Literal* const> arg_literals); -template StatusOr<std::unique_ptr<Literal>> -HloEvaluator::Evaluate<std::unique_ptr<Literal>>( - const HloModule& module, - absl::Span<const std::unique_ptr<Literal>> arg_literals); - -template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - absl::Span<const Literal* const> arg_literals); -template StatusOr<std::unique_ptr<Literal>> -HloEvaluator::Evaluate<std::unique_ptr<Literal>>( + +template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>( const HloComputation& computation, - absl::Span<const std::unique_ptr<Literal>> arg_literals); + absl::Span<const Literal* const> arg_literals); -template StatusOr<std::unique_ptr<Literal>> -HloEvaluator::Evaluate<const Literal*>( +template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>( HloInstruction* instruction, absl::Span<const Literal* const> arg_literals); -template StatusOr<std::unique_ptr<Literal>> -HloEvaluator::Evaluate<std::unique_ptr<Literal>>( - HloInstruction* instruction, - absl::Span<const std::unique_ptr<Literal>> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 72252bafc7..21e676d671 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template <typename LiteralPtr> - StatusOr<std::unique_ptr<Literal>> Evaluate( - const HloModule& module, absl::Span<const LiteralPtr> arg_literals); + StatusOr<Literal> Evaluate(const HloModule& module, + absl::Span<const LiteralPtr> arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template <typename LiteralPtr> - StatusOr<std::unique_ptr<Literal>> Evaluate( - const HloComputation& computation, - absl::Span<const LiteralPtr> arg_literals); + StatusOr<Literal> Evaluate(const HloComputation& computation, + absl::Span<const LiteralPtr> arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template <typename LiteralPtr> - StatusOr<std::unique_ptr<Literal>> Evaluate( - HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals); + StatusOr<Literal> Evaluate(HloInstruction* instruction, + absl::Span<const LiteralPtr> arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction); + StatusOr<Literal> Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions( + StatusOr<Literal> EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map<const HloInstruction*, const Literal*>& substitutions); - StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr<std::unique_ptr<Literal>> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, const Literal& lhs, - const Literal& rhs); + StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. @@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use FlatMap any more. + std::unordered_map<const HloInstruction*, Literal> evaluated_; private: template <typename ReturnT, typename NativeT> - static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl( + static StatusOr<Literal> ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function<ReturnT(NativeT)>& unary_op, const Literal& operand_literal) { @@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape())); } - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { return unary_op(operand_literal.Get<NativeT>(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 102ebb24ab..01e88566a5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, evaluator_ = absl::make_unique<HloEvaluator>(); } - std::unique_ptr<Literal> Evaluate( - absl::Span<const Literal* const> arg_literals = {}) { + Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); @@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, std::unique_ptr<HloEvaluator> evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected, - std::unique_ptr<Literal> input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected, - std::unique_ptr<Literal> lhs, - std::unique_ptr<Literal> rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -117,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0<float>(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}}); - std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector<const Literal*> args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. @@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral<F32>( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type; - result->EachCell<NativeT>( - [&](absl::Span<const int64> indices, NativeT value) { - std::vector<int64> rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250); - }); + result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) { + std::vector<int64> rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<int64>( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1<int64>({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<int32>( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique<Array2D<float>>(1, 5); @@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique<Array2D<float>>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D<float>({ @@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D<float>({ {22.f, 28.f}, @@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array4D<float> expected_array(1, 1, 4, 4); // clang-format off @@ -862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = LiteralUtil::CreateR4FromArray4D<float>( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = LiteralUtil::CreateR4FromArray4D<float>( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array4D<float> expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D<float>({ @@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array4D<float> expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D<float>({ @@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array4D<float> expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D<float>({ @@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { @@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), -7); auto input_r1 = LiteralUtil::CreateR1<float>(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), -31); auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); @@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); Array4D<float> expected_array(1, 1, 1, 8); expected_array.FillWithYX( Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}})); auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr<Literal> result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1<float>({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector<int64> input_dims(6, 4); - std::unique_ptr<Literal> arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr<Literal> result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<float>({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<double>({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2<double>({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D<double>(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr<Literal> result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D<float>({ @@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}}); + add, {{param0, ¶m0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1906,12 +1901,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1930,12 +1925,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1954,14 +1949,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> start_indices = - LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3<int32>( + LiteralUtil::CreateR3<int32>( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), start_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1980,15 +1974,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> start_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -2008,15 +2001,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> start_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -2035,12 +2027,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -2059,13 +2050,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> start_indices = - LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2084,11 +2074,10 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2108,12 +2097,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); - std::unique_ptr<Literal> start_indices = + Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2138,15 +2127,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2171,15 +2158,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2205,15 +2191,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2239,15 +2223,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { @@ -2273,17 +2255,15 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>( + Literal operand = LiteralUtil::CreateR2<float>( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({2, 1}); - std::unique_ptr<Literal> updates = + Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1}); + Literal updates = LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2<float>( + LiteralUtil::CreateR2<float>( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), - ErrorSpec{0.1, 0.01})); + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2309,15 +2289,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2343,15 +2321,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3<int32>( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2376,21 +2353,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}}); - std::unique_ptr<Literal> expected = + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}}); + Literal expected = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, @@ -2416,21 +2390,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}}); - std::unique_ptr<Literal> expected = + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}}); + Literal expected = LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2455,16 +2426,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({1, 1}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}}); - std::unique_ptr<Literal> expected = + Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1}); + Literal updates = LiteralUtil::CreateR2<int32>({{10}}); + Literal expected = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2489,17 +2458,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = + Literal operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR3<int32>({{{10}}, {{20}}}); - std::unique_ptr<Literal> expected = + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}}); + Literal expected = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2524,13 +2490,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); - std::unique_ptr<Literal> scatter_indices = - LiteralUtil::CreateR1<int32>({0, 2}); - std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}}); + Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{}, {}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *operand, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + operand, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2557,16 +2521,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); - std::unique_ptr<Literal> scatter_indices = + Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr<Literal> updates = - LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}}); - std::unique_ptr<Literal> expected = - LiteralUtil::CreateR1<int32>({10, 61, 32}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2603,11 +2564,29 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>( + Literal arg = LiteralUtil::CreateR1<bfloat16>( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); - std::unique_ptr<Literal> expected = - LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); + Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); +} + +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout<float>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 63303aef1e..8fb17a0033 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique<Literal>(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> out_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> out_index) { std::vector<int64> from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1157,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast<ReturnT>(result_val); }; - auto result = absl::make_unique<Literal>(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1231,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = absl::make_unique<Literal>(dot->shape()); + Literal result(dot->shape()); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> result_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> result_index) { ElementwiseT result_val = static_cast<ElementwiseT>(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1280,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({}); - auto result = absl::make_unique<Literal>(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate<ReturnT>( [&scalar](absl::Span<const int64> multi_index) { return scalar; })); const Literal& evaluated_operand = @@ -1289,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1311,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set<ReturnT>(target_index, - evaluated_operand.Get<ReturnT>(input_index)); + result.Set<ReturnT>(target_index, + evaluated_operand.Get<ReturnT>(input_index)); return true; }; @@ -1439,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename NativeT> - StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) { + StatusOr<Literal> MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique<Literal>(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { - std::vector<std::unique_ptr<Literal>> arg_literals; + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + std::vector<Literal> arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1463,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr<Literal> computed_result = - embedded_evaluator - .Evaluate<std::unique_ptr<Literal>>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate<Literal>(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. embedded_evaluator.ResetVisitStates(); - return computed_result->Get<ReturnT>({}); + return computed_result.Get<ReturnT>({}); })); return std::move(result); } @@ -1557,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess<ReturnT>(a, b); }); - auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); - result_literal->PopulateR1(absl::Span<const ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + Literal result_literal(keys_literal.shape()); + result_literal.PopulateR1(absl::Span<const ReturnT>(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); return result_literal; }; @@ -1568,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); + Literal result_literal(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + .Reshape({r1_length})); + auto r1_result = sort_r1(r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + r1_result, {0, 0}, {row, 0}, {1, r1_length})); } parent_->evaluated_[sort] = std::move(result_literal); } @@ -1651,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args); + absl::InlinedVector<Literal, 1> results(num_args); for (int64 i = 0; i < num_args; ++i) { - results[i] = absl::make_unique<Literal>(result_shape); + results[i] = Literal(result_shape); } Status eval_status; @@ -1667,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>( + TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>( [&](absl::Span<const int64> multi_index) { if (!eval_status.ok()) { return init_scalars[input]; @@ -1703,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Evaluate computation with specified literal operands. - absl::InlinedVector<std::unique_ptr<Literal>, 1> - embedded_operands; + absl::InlinedVector<Literal, 1> embedded_operands; for (ReturnT value : result_values) { embedded_operands.push_back( LiteralUtil::CreateR0<ReturnT>(value)); @@ -1717,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_operands.size()); std::transform(embedded_operands.begin(), embedded_operands.end(), embedded_operands_ptrs.begin(), - [](const std::unique_ptr<Literal>& ptr) { - return ptr.get(); - }); + [](Literal& literal) { return &literal; }); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result, + TF_ASSIGN_OR_RETURN(Literal computed_result, embedded_evaluator.Evaluate<const Literal*>( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on @@ -1729,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. if (!has_tuple_output) { - result_values[0] = computed_result->Get<ReturnT>({}); + result_values[0] = computed_result.Get<ReturnT>({}); } else { for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get<ReturnT>( + result_values[i] = computed_result.Get<ReturnT>( /*multi_index=*/{}, /*shape_index=*/{i}); } } @@ -1748,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!has_tuple_output) { parent_->evaluated_[reduce] = std::move(results[0]); } else { - auto tuple_result = absl::make_unique<Literal>(reduce->shape()); + Literal tuple_result(reduce->shape()); for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); } parent_->evaluated_[reduce] = std::move(tuple_result); } @@ -1781,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get<ReturnT>({}); - auto result = absl::make_unique<Literal>(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + TF_RETURN_IF_ERROR(result.Populate<ReturnT>( [&](absl::Span<const int64> output_index) { return init_scalar; })); std::vector<int64> window_dimension_sizes; @@ -1834,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr<Literal> computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate<const Literal*>( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get<bool>({}); + bool selected = !computed_result.Get<bool>({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1856,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get<ReturnT>(source_index); - auto scattered = result->Get<ReturnT>(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr<Literal> computed_result = + auto scattered = result.Get<ReturnT>(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate<const Literal*>(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate<const Literal*>( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get<ReturnT>({})); + result.Set(operand_index, computed_result.Get<ReturnT>({})); // Clear visit states so that the we can use the evaluator again // on the same computation. embedded_evaluator.ResetVisitStates(); @@ -1916,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique<Literal>(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> output_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1935,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0<ReturnT>(curr_val); const auto result_val_literal = LiteralUtil::CreateR0<ReturnT>(result_val); - std::unique_ptr<Literal> computed_result = + Literal computed_result = embedded_evaluator .Evaluate<const Literal*>( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get<ReturnT>({}); + result_val = computed_result.Get<ReturnT>({}); }); return result_val; @@ -1961,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr<Literal>* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1970,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -2230,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr<Literal> reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2260,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. - std::unique_ptr<Literal> result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = [&](absl::Span<const int64> update_window_index, @@ -2299,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result_value_literal = - LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index)); + LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index)); auto update_value_literal = LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index)); - std::unique_ptr<Literal> updated_result = + Literal updated_result = embedded_evaluator .Evaluate<const Literal*>( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. embedded_evaluator.ResetVisitStates(); - result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({})); + result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({})); return true; }; @@ -2359,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get<ReturnT>(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2575,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (ShapeUtil::Rank(iota->shape()) > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], - result->Broadcast(iota->shape(), {iota->iota_dimension()})); + result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); parent_->evaluated_[iota] = std::move(result); @@ -2645,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr<Literal> DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data<IndexT>(); std::vector<int64> start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2660,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector<int64> operand_indices(start.size()); - auto result = absl::make_unique<Literal>(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2676,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data<IndexT>(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector<int64> start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2689,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min<int64>( std::max<int64>(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector<int64> result_index(rank, 0); auto func = [&](absl::Span<const int64> update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus<int64>()); - result->Set<ReturnT>(result_index, - update_literal.Get<ReturnT>(update_index)); + result.Set<ReturnT>(result_index, + update_literal.Get<ReturnT>(update_index)); return true; }; @@ -2710,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp( + StatusOr<Literal> ElementWiseUnaryOp( HloInstruction* instruction, const std::function<ElementwiseT(ElementwiseT)>& unary_op) { const Literal& operand_literal = @@ -2723,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp( + StatusOr<Literal> ElementWiseBinaryOp( HloInstruction* instruction, const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& binary_op) { @@ -2745,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get<ReturnT>(multi_index), rhs_literal.Get<ReturnT>(multi_index)); @@ -2757,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename LhsType, typename RhsType, typename EhsType> - StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp( + StatusOr<Literal> ElementwiseTernaryOp( HloInstruction* instruction, const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) { const auto shape = instruction->shape(); @@ -2782,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique<Literal>(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) { + result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) { return ternary_op(lhs_literal.Get<LhsType>(multi_index), rhs_literal.Get<RhsType>(multi_index), ehs_literal.Get<EhsType>(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0345a2a5f8..287ba84b3b 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -123,6 +123,10 @@ class NodeFilter { // We arbitrarily set this as the boundary between "large" and "small" // instructions. bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { + return true; + } return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; } @@ -465,9 +469,8 @@ stylesheet=< string graph_label = StrCat(label_, "<br/>Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 25ae344ea5..e905f2983a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -527,7 +528,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant( - std::unique_ptr<Literal> literal) { + Literal literal) { return absl::make_unique<HloConstantInstruction>(std::move(literal)); } @@ -2096,7 +2097,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5581c17c2d..4f6cac1396 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -82,6 +82,7 @@ class HloPrintOptions { print_operand_shape_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -94,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -108,6 +110,7 @@ class HloPrintOptions { .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -153,6 +156,12 @@ class HloPrintOptions { return *this; } + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { @@ -190,6 +199,9 @@ class HloPrintOptions { bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -205,6 +217,7 @@ class HloPrintOptions { bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; @@ -346,8 +359,7 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr<HloInstruction> CreateConstant( - std::unique_ptr<Literal> literal); + static std::unique_ptr<HloInstruction> CreateConstant(Literal literal); // Creates an Iota instruction. static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index fb7345a2ad..e92882c22a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; @@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction> HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const { - return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique<HloConstantInstruction>(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple @@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index c3a7801164..2d7bc83855 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr<Literal> literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction { std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr<Literal> literal_; + absl::optional<Literal> literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction { std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr<Literal> literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 9bfb0af96c..c7ec88d450 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include <map> #include <queue> @@ -582,4 +582,22 @@ StatusOr<HloInstructionSequence> ScheduleComputation( size_function, nullptr, empty_map); } +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr<bool> HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 54e32340ba..5e02868eba 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #include <vector> #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -86,6 +87,37 @@ StatusOr<HloInstructionSequence> ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloPassInterface { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr<bool> Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModudle::has_schedule will return false. +class HloDescheduler : public HloPassInterface { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr<bool> Run(HloModule* module) override; +}; + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 6afe51997e..1b9e9bfc77 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include <memory> #include <string> @@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + // Verify that all instructions are in the sequence. const std::vector<const HloInstruction*>& sequence = - schedule.sequence(module->entry_computation()).instructions(); + module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". EXPECT_EQ(param, sequence.front()); EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering ordering(module->schedule()); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. + HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); } TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cfe906d9c5..b3949f3a6d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) { HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. tensorflow::gtl::FlatSet<string> computation_names; tensorflow::gtl::FlatSet<string> instruction_names; + tensorflow::gtl::FlatSet<int> computation_ids; + tensorflow::gtl::FlatSet<int> instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 26fd1b2438..3bc2d13781 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -253,7 +253,7 @@ class HloModule { private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); const string name_; HloModuleConfig config_; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 98d20315e3..f7be5cae22 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -36,23 +36,6 @@ namespace xla { namespace { -bool HasSendRecv(HloComputation* computation) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kSendDone || - instruction->opcode() == HloOpcode::kRecv || - instruction->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (auto* sub_computation : instruction->called_computations()) { - if (HasSendRecv(sub_computation)) { - return true; - } - } - } - return false; -} - StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { bool changed = false; for (auto* computation : module->computations()) { @@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { if (!ShapeUtil::IsTuple(xla_while->shape()) || while_body_root->opcode() != HloOpcode::kTuple || - HasSendRecv(while_body_comp)) { + while_body_comp->HasSideEffect() || + xla_while->while_condition()->HasSideEffect()) { // Only run DCE on tuple-shaped while loops where body root is Tuple, - // with no send/recv instructions. + // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); continue; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e490..bf66cc6bc3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. +TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. +TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000..f9b56ef464 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(absl::string_view name, + std::unique_ptr<HloModule> module) + : name_(name) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span<std::unique_ptr<HloModule>> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() { + std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. + modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span<const HloModuleConfig> module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector<std::unique_ptr<HloModule>> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<HloModule> module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000..7338be8b9c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include <iosfwd> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing a ordered set of HLO module built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span<std::unique_ptr<HloModule>> modules); + + // Returns the modules contained in the group. + const std::vector<HloModule*>& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of vector of modules in the group. + void push_back(std::unique_ptr<HloModule> module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. + std::vector<std::unique_ptr<HloModule>> ConsumeModules(); + + string name() const { return name_; } + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr<HloModuleGroup> CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span<const HloModuleConfig> module_configs); + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector<std::unique_ptr<HloModule>> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector<HloModule*> module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000..ebf790ba6f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + HloModuleGroup group(TestName(), std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1, + ParseHloString(text_1)); + std::vector<std::unique_ptr<HloModule>> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 400bd4d947..39f38b417a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" - #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instruction and module. + const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. + auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. + int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 6b6005e7a5..00970bcda3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c54360b063..11caa89c54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -105,16 +105,13 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template <typename LiteralNativeT> - bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. @@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr<Literal> literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector<std::unique_ptr<Literal>> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal, } template <typename LiteralNativeT> -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector<tensorflow::int64> index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique<Literal>(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b..d9848cee0b 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 0a0a6a323e..bd6dd79b67 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -27,15 +27,14 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -1194,51 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( return changed; } -StatusOr<bool> HloRematerialization::Run(HloModule* module, - HloSchedule* schedule, - int64 memory_limit_bytes, - RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The schedule is constructed entirely by this method. - TF_RET_CHECK(schedule->empty()); - +StatusOr<bool> HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial schedule of HLO instructions. - TF_ASSIGN_OR_RETURN(*schedule, - ScheduleModule(*module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(*schedule); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - - // RemoveUnnecessaryCopies only considers interference when determining - // whether it is legal to remove a copy. However, copies in the graph may be - // necessary for other reason such as preventing a constant from being live - // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. - // TODO(b/80249101): Break copy insertion into several passes and run each - // one once in the regular HLO pipeline. - TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); - - // The passes above can add and remove copies, update the schedule to - // account for these transformations. Newly added instructions will be - // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(schedule->Update()); - - TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(*schedule), module)); - } - + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1254,7 +1214,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module, }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -1263,13 +1223,14 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module, // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, schedule](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory( - node.computation(), - schedule->sequence(node.computation()).instructions())); + ComputePeakMemory(node.computation(), + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1287,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module, // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), schedule, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1298,7 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module, // After DCE, the module sequence may include instructions which no longer // exist. - TF_RETURN_IF_ERROR(schedule->Update()); + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1315,32 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module, << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { + if (current_peak_memory > memory_limit_bytes_) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use below %s (%d bytes) by rematerialization; " "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index fa0414b472..e2aaf18b3e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,17 +17,23 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloPassInterface { public: using ShapeSizeFunction = std::function<int64(const Shape&)>; @@ -38,10 +44,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -49,51 +52,27 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. // - // hlo_module: HLO module to rematerialize instructions in. - // - // schedule: Should point to an empty HloSchedule. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. - // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr<bool> RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - HloSchedule* schedule, RematerializationSizes* sizes, - CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Value are set during + // Run(). Can be nullptr. + HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. schedule should be an empty HloSchedule. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr<bool> Run(HloModule* module, HloSchedule* schedule, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr<bool> Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation @@ -121,6 +100,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. std::unique_ptr<CallGraph> call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83cb113bfb..f7e82fb1f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase { } StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module, - HloSchedule* schedule) { + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - schedule, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &schedule)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 66ac1f66fd..fa7f216321 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice( } StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice( - const absl::Span<const std::unique_ptr<Literal>> literals) { + const absl::Span<const Literal> literals) { std::vector<const Literal*> literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice( +StatusOr<Literal> HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( +StatusOr<Literal> HloRunner::Execute( std::unique_ptr<HloModule> module, const absl::Span<const Literal* const> arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr<std::unique_ptr<Literal>> HloRunner::Execute( - std::unique_ptr<HloModule> module, - const absl::Span<const std::unique_ptr<Literal>> arguments, - bool run_hlo_passes, ExecutionProfile* profile) { +StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module, + const absl::Span<const Literal> arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector<const Literal*> argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( +StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated( std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique<Literal>(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector<std::unique_ptr<Literal>> exec_results; + std::vector<Literal> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, + TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 76d8b92bed..2e934bf66a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. If // nullptr, the values will be read and discarded. - std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr; + std::vector<Literal>* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -106,24 +106,23 @@ class HloRunner { StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( const absl::Span<const Literal* const> literals); StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( - const absl::Span<const std::unique_ptr<Literal>> literals); - StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span<const Literal> literals); + StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - StatusOr<std::unique_ptr<Literal>> Execute( - std::unique_ptr<HloModule> module, - const absl::Span<const Literal* const> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, + const absl::Span<const Literal* const> arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr<std::unique_ptr<Literal>> Execute( - std::unique_ptr<HloModule> module, - const absl::Span<const std::unique_ptr<Literal>> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, + const absl::Span<const Literal> arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. @@ -140,7 +139,7 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated( + StatusOr<std::vector<Literal>> ExecuteReplicated( std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index eb52582bb5..1424569ac1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2..6fd734a2b9 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 069586a738..50f39cbcb5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1123,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0cac210c24..8f0423bb1c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); auto module = CreateNewModule(); @@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 37b774b8a5..06f0e1ed25 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr<Literal> inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 9746d176cc..df9cbab915 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -347,21 +347,19 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr<Literal> literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr<Literal*> TakeOwnership( - StatusOr<std::unique_ptr<Literal>> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, - std::move(literal_or_error)); + StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector<std::unique_ptr<Array>> owned_tensors_; - std::vector<std::unique_ptr<Literal>> owned_literals_; + std::vector<Literal> owned_literals_; tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_; }; diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 5695bc2420..7e967f035c 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloTestBase; +using InlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(InlinerTest, MapMax) { @@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. @@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8c907eae0c..3fdc2cee9a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. +class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair<HloInstruction*, std::vector<int64>> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector<int64> sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. + InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector<HloInstruction*> post_order_; + tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_; +}; + +} // namespace + +std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function<bool(HloInstruction*)>& skip_producer) { + return absl::make_unique<ReversePostOrderFusionQueue>(computation); +} + StatusOr<bool> InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector<HloInstruction*> post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. - // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector<int64> sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector<int64>& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); @@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 00b658959a..c1fde8ecfc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,6 +24,33 @@ limitations under the License. namespace xla { +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair<HloInstruction*, std::vector<int64>> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface { static bool IsExpensive(const HloInstruction& instruction); protected: + // Returns a FusionQueue that implements custom order of instructions being + // fused. The default implementation processes consumers in reverse post + // order. + virtual std::unique_ptr<FusionQueue> GetFusionQueue( + HloComputation* computation, + const std::function<bool(HloInstruction*)>& skip_producer); + // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. // Derived classes should define this method to specify which instructions diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 5dea124768..a06d6113e8 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. - std::vector<std::unique_ptr<Literal>> arg_literals; + std::vector<Literal> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr<Literal> result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate<std::unique_ptr<Literal>>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 69c7e42601..752a61476d 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, @@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); @@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. @@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto module = CreateNewModule(); @@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr<HloModule> compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(&module(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { @@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest, } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index f0e2566a3f..b27a92f2a0 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments, module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr<Literal> literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr<Literal> literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( - std::unique_ptr<Literal> result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector<se::StreamExecutor*> replicas; if (arg->has_device_handle()) { @@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), *literal)); - *result->mutable_literal() = literal->ToProto(); + executor, arg->shape_with_layout(), literal)); + *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModule::CreateFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate<std::unique_ptr<Literal>>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>( + *module, /*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } @@ -1162,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1170,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 44c5248b15..1f62fad4c8 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -271,7 +271,9 @@ class Service : public ServiceInterface { StatusOr<std::vector<se::StreamExecutor*>> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc deleted file mode 100644 index dd53c7531b..0000000000 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/source_map_util.h" - -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace source_map_util { -namespace { - -Status InvalidParameterArgumentV(const OpMetadata& op_metadata, - const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - if (!op_metadata.source_file().empty()) { - absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), - op_metadata.source_line()); - } - return InvalidArgument("%s", message); -} - -} // namespace - -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidParameterArgumentV(op_metadata, format, args); - va_end(args); - return result; -} - -Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) { - va_list args; - va_start(args, format); - if (executable != nullptr && executable->has_module()) { - const HloModule& module = executable->module(); - const HloComputation& computation = *module.entry_computation(); - HloInstruction* param = computation.parameter_instruction(parameter_number); - const OpMetadata& metadata = param->metadata(); - Status result = InvalidParameterArgumentV(metadata, format, args); - va_end(args); - return result; - } - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -} // namespace source_map_util -} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index b8d2d546e5..a21e586efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice( +StatusOr<Literal> TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr<std::unique_ptr<Literal>> ret; + StatusOr<Literal> ret; se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); @@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return absl::make_unique<Literal>(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferLiteralFromDevice( @@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice( +StatusOr<Literal> TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { - StatusOr<std::unique_ptr<Literal>> ret; + StatusOr<Literal> ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. @@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return absl::make_unique<Literal>(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferArrayToDevice( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 21725946b3..f952e64af2 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -57,7 +57,7 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. - virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice( + virtual StatusOr<Literal> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, @@ -113,9 +113,9 @@ class TransferManager { Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 2b2a2eb42a..e9a07b14ed 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1<float>({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}), + LiteralUtil::CreateR1<float>({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872d..516754e211 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index c3c2603c7e..541b117e02 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr<std::unique_ptr<Literal>> indvar_init_result = - evaluator.Evaluate(indvar_init); + StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op, auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); // The initial value of the induction variable. - std::unique_ptr<Literal> indvar_iter_val = - std::move(indvar_init_result).ValueOrDie(); + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); for (int64 trip_count = 0; trip_count != max_value_returned + 1; ++trip_count) { auto* while_cond = while_op->while_condition(); auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr<std::unique_ptr<Literal>> result = - evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) { + if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr<std::unique_ptr<Literal>> indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, - {{while_body_indvar, indvar_iter_val.get()}}); + StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); |