diff options
3 files changed, 7 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 652b5c7687..ea661b3c2c 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -113,10 +113,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. return IsInputFusibleReduction(instr) || (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop && - // TODO(b/110202584): bitcasts make nested fusions, GPU has no support - // for nested fusions. - instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 79b5a442aa..4166ef5baf 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - // - // An invariant is that no bitcast nodes will show up in the middle of a - // fusion node. This invariant must hold in order for us to lower it. Given - // that, we require that during multi-output fusion, a fusion node ending with - // bitcast to preserve its structure as a nested fusion instead being - // merged and flattened. - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { - if (remaining->opcode() == HloOpcode::kFusion && - remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { - auto parent_computation = remaining->parent(); - // Create a nested fusion node. - auto remaining_nested_fused = - parent_computation->AddInstruction(HloInstruction::CreateFusion( - remaining->shape(), HloInstruction::FusionKind::kLoop, - remaining)); - TF_CHECK_OK(parent_computation->ReplaceInstruction( - remaining, remaining_nested_fused)); - remaining = remaining_nested_fused; - } remaining->FuseInstructionIntoMultiOutput(fused); } - return remaining; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d23822e33e..0019cd7254 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -78,6 +78,10 @@ class MultiOutputFusion : public HloPassInterface { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + // Recompute reachability for the current computation. void RecomputeReachability(); @@ -101,10 +105,6 @@ class MultiOutputFusion : public HloPassInterface { virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. - // The other instruction is removed from its parent computation. - HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); - // Update the internal data structures after instr1 and instr2 are fused into // one fusion instruction. void Update(HloInstruction* instr1, HloInstruction* instr2); |