about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/gpu/fusion_merger.cc | 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc | 33
2 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 1bd88233e1..30c1f90889 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -225,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
- user->fusion_kind() == HloInstruction::FusionKind::kInput);
+ (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+ LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index b22bb1d39b..7cc869ed9e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
op::Fusion(op::Parameter()));
}
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+ auto module = ParseHloString(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+ // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .ValueOrDie();
+ EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
} // namespace
} // namespace gpu
} // namespace xla