From ee6265c30edcfe500c3caa5dec338f4fddd4943f Mon Sep 17 00:00:00 2001
From: Thomas Joerg
Date: Thu, 30 Aug 2018 05:38:27 -0700
Subject: [XLA:GPU] Do not merge loop fusions into reduce input fusions if the
 resulting reduce kernel suffers from poor data locality.

PiperOrigin-RevId: 210894866
---
 .../compiler/xla/service/gpu/fusion_merger.cc      |  6 ++--
 .../compiler/xla/service/gpu/fusion_merger_test.cc | 33 ++++++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 1bd88233e1..30c1f90889 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/algorithm/container.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -225,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
   // Skip 'fusion' instruction if we cannot merge into all of its users.
   // Merging into all users enables the removal of 'fusion' from the
   // computation.
-  if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
+  if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
         return user->opcode() == HloOpcode::kFusion &&
                (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
-                user->fusion_kind() == HloInstruction::FusionKind::kInput);
+                (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+                 LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
       })) {
     VLOG(3) << "Not merging " << fusion->name()
             << ": Some of its users are not loop/input fusion kernels.";
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index b22bb1d39b..7cc869ed9e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
               op::Fusion(op::Parameter()));
 }
 
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+  auto module = ParseHloString(R"(
+    HloModule m
+
+    f1_computation {
+      f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+      add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+      // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+      ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+    }
+
+    add_computation {
+      add_lhs = f32[] parameter(0)
+      add_rhs = f32[] parameter(1)
+      ROOT add_root = f32[] add(add_lhs, add_rhs)
+    }
+
+    f2_computation {
+      f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+      f2_zero = f32[] constant(0)
+      ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+                     to_apply=add_computation
+    }
+
+    ENTRY entry {
+      p0 = f32[16,16,256]{0,1,2} parameter(0)
+      f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+      ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+    })")
+                    .ValueOrDie();
+  EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
--
cgit v1.2.3