author     2018-08-03 05:23:47 -0700
committer  2018-08-03 05:28:00 -0700
commit     b2933c618260edc039fb8a7e2dce4d2e185f0892 (patch)
tree       57aabd3d47f32b8f1bdd196c6859a255a8dea1c4 /tensorflow
parent     37b48fac2c365c49373467abf5fc58c4678e700e (diff)
[XLA:GPU] Allow multi-output fusion of element-wise instructions, in addition to loop fusions.
PiperOrigin-RevId: 207253181
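
In short: the producer-consumer multi-output fusion pass previously considered only producers that were loop fusions; this patch also admits bare element-wise instructions (such as the exponential in the new test). The sketch below restates the relaxed eligibility check using simplified stand-in types; Instr, Opcode, FusionKind, and IsFusibleProducer are illustrative names, not the actual XLA HloInstruction API.

// Simplified stand-ins for illustration only; the real pass works on
// HloInstruction / HloOpcode from the XLA service library.
enum class Opcode { kFusion, kExp, kAdd };
enum class FusionKind { kLoop, kInput, kNone };

struct Instr {
  Opcode opcode;
  FusionKind fusion_kind;
  bool is_elementwise;
};

// Before this change only loop fusions qualified as producers; after it,
// an unfused element-wise instruction qualifies as well.
bool IsFusibleProducer(const Instr& producer) {
  const bool is_loop_fusion = producer.opcode == Opcode::kFusion &&
                              producer.fusion_kind == FusionKind::kLoop;
  return producer.is_elementwise || is_loop_fusion;
}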
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc      | 14 +++++++++++---
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc | 20 ++++++++++++++++++++
2 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c67dcbce77..c62bae0628 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -115,15 +115,23 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
 // will be broadcasted and have not been observed to cause data locality issues.
 // TODO(b/111977086): Improve reduce emitters to remove this limitation.
 bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
+  std::vector<HloInstruction*> params;
+  if (instr->opcode() == HloOpcode::kFusion) {
+    params = instr->fused_parameters();
+  } else {
+    for (HloInstruction* operand : instr->operands()) {
+      params.push_back(operand);
+    }
+  }
   int64 max_rank = 0;
   const Layout* max_rank_layout;
-  for (HloInstruction* param : instr->fused_parameters()) {
+  for (HloInstruction* param : params) {
     if (ShapeUtil::Rank(param->shape()) > max_rank) {
       max_rank = ShapeUtil::Rank(param->shape());
       max_rank_layout = &param->shape().layout();
     }
   }
-  return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
+  return c_all_of(params, [&](HloInstruction* param) {
     return (ShapeUtil::Rank(param->shape()) < max_rank) ||
            (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
   });
@@ -221,7 +229,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
     const bool is_loop_fusion =
         producer->opcode() == HloOpcode::kFusion &&
         producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
-    if (!is_loop_fusion) {
+    if (!producer->IsElementwise() && !is_loop_fusion) {
       VLOG(3) << producer->name() << " is not a loop fusion.";
       continue;
     }
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index ec4234b8d9..14f157a5e5 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -256,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
                       op::Tuple(op::Multiply(), op::Divide()));
 }
 
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+    ENTRY reduce {
+      p0 = f32[2,2,2]{2,1,0} parameter(0)
+      c0 = f32[] constant(0)
+      exp = f32[2,2,2]{2,1,0} exponential(p0)
+      reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
+      ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
+    })"))
+                    .ValueOrDie();
+  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+  const HloInstruction* fusion = root->operand(0)->operand(0);
+  ASSERT_TRUE(fusion->IsMultiOutputFusion());
+  EXPECT_THAT(fusion->fused_expression_root(),
+              op::Tuple(op::Reduce(), op::Exp()));
+}
+
 TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
   auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
     fused_add {
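
A note on the first hunk: ReduceFriendlyInputLayouts used to assume its argument was a fusion and unconditionally called fused_parameters(); now that bare element-wise producers reach this check, the patch collects the instruction's operands instead when it is not a fusion. The layout test itself is unchanged: find the highest-rank parameter, then require every parameter of that rank to share its layout (lower-rank parameters are exempt because they will be broadcast). A self-contained sketch of that logic over simplified shapes; Param, rank, and layout are illustrative placeholders, not the XLA Shape/Layout types.

#include <algorithm>
#include <vector>

// Illustrative placeholder: a parameter reduced to the two properties the
// check cares about. XLA's version reads these from Shape and Layout.
struct Param {
  int rank;    // number of dimensions
  int layout;  // stand-in for a minor-to-major layout permutation
};

// Mirrors the generalized check: locate the maximum rank, then require all
// parameters of that rank to agree on layout.
bool ReduceFriendlyInputLayouts(const std::vector<Param>& params) {
  int max_rank = 0;
  int max_rank_layout = 0;
  for (const Param& p : params) {
    if (p.rank > max_rank) {
      max_rank = p.rank;
      max_rank_layout = p.layout;
    }
  }
  return std::all_of(params.begin(), params.end(), [&](const Param& p) {
    return p.rank < max_rank || p.layout == max_rank_layout;
  });
}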