diff options
author | Thomas Joerg <tjoerg@google.com> | 2018-08-31 00:50:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 00:54:59 -0700 |
commit | 8155dfc0119587ed07dd05a641480a9636f91b04 (patch) | |
tree | 6c418fb817684a6c945db33ea1789da9c36c9991 | |
parent | 702d0f8d8a72075531c6016f2d6fc6d936cee95e (diff) |
[XLA:GPU] Instruction fusion: Do not fuse into reduce input fusions if the resulting kernels suffer from poor data locality.
PiperOrigin-RevId: 211046900
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/instruction_fusion.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 72 |
2 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index de8d021321..4d5d8e99f8 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -221,6 +222,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Do not fuse into reduce input fusions if the resulting kernel would suffer + // from poor data locality (due to unfriendly input layouts). + if (IsInputFusibleReduction(*consumer) && + !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { + return false; + } + // We can't fuse library calls, so if a user of such an op could become a // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for // further rationale. 
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index f53dfaee3d..bca775c475 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -171,6 +171,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + constant.1 = f32[] constant(0) + ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + fused_reduce { + p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0) + mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1) + c0.1 = f32[] constant(0) + ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce + ROOT root = (f32[]) tuple(fusion) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy())); +} + TEST_F(InstructionFusionTest, BitcastIntoAdd) { auto module = ParseHloString(R"( HloModule test_module |