about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author    Thomas Joerg <tjoerg@google.com>  2018-08-30 05:38:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 05:42:53 -0700
commit ee6265c30edcfe500c3caa5dec338f4fddd4943f (patch)
tree   b54e5214f2f90c718e038dcc183fc65ca7dc479a
parent 8032d9f79e2ed5ccdae642257da21a989965c8cf (diff)
[XLA:GPU] Do not merge loop fusions into reduce input fusions if the resulting reduce kernel suffers from poor data locality.
PiperOrigin-RevId: 210894866
 tensorflow/compiler/xla/service/gpu/fusion_merger.cc      |  6 +++++-
 tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc | 33 ++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 1bd88233e1..30c1f90889 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -225,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
- user->fusion_kind() == HloInstruction::FusionKind::kInput);
+ (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+ LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index b22bb1d39b..7cc869ed9e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
op::Fusion(op::Parameter()));
}
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+ auto module = ParseHloString(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+ // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .ValueOrDie();
+ EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
} // namespace
} // namespace gpu
} // namespace xla