aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc 5
-rw-r--r-- tensorflow/compiler/xla/service/multi_output_fusion.cc 25
-rw-r--r-- tensorflow/compiler/xla/service/multi_output_fusion.h 8
3 files changed, 7 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 652b5c7687..ea661b3c2c 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -113,10 +113,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions.
return IsInputFusibleReduction(instr) ||
(instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- // TODO(b/110202584): bitcasts make nested fusions, GPU has no support
- // for nested fusions.
- instr->fused_expression_root()->opcode() != HloOpcode::kBitcast);
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 79b5a442aa..4166ef5baf 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
HloInstruction* fused = instr2;
// Make sure that if only one of the instructions is a fusion, or if only one
// of the instructions is a multi-output fusion, it's what will be fused into.
- //
- // An invariant is that no bitcast nodes will show up in the middle of a
- // fusion node. This invariant must hold in order for us to lower it. Given
- // that, we require that during multi-output fusion, a fusion node ending with
- // bitcast to preserve its structure as a nested fusion instead being
- // merged and flattened.
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
std::swap(remaining, fused);
}
if (fused->IsMultiOutputFusion()) {
std::swap(remaining, fused);
}
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
remaining->MergeFusionInstructionIntoMultiOutput(fused);
} else {
- if (remaining->opcode() == HloOpcode::kFusion &&
- remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) {
- auto parent_computation = remaining->parent();
- // Create a nested fusion node.
- auto remaining_nested_fused =
- parent_computation->AddInstruction(HloInstruction::CreateFusion(
- remaining->shape(), HloInstruction::FusionKind::kLoop,
- remaining));
- TF_CHECK_OK(parent_computation->ReplaceInstruction(
- remaining, remaining_nested_fused));
- remaining = remaining_nested_fused;
- }
remaining->FuseInstructionIntoMultiOutput(fused);
}
-
return remaining;
}
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d23822e33e..0019cd7254 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -78,6 +78,10 @@ class MultiOutputFusion : public HloPassInterface {
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
+ // Fuse HloInstruction instr1 and instr2 and return the fused instruction.
+ // The other instruction is removed from its parent computation.
+ virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
+
// Recompute reachability for the current computation.
void RecomputeReachability();
@@ -101,10 +105,6 @@ class MultiOutputFusion : public HloPassInterface {
virtual bool DoProducerConsumerMultiOutputFusion();
private:
- // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction.
- // The other instruction is removed from its parent computation.
- HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
-
// Update the internal data structures after instr1 and instr2 are fused into
// one fusion instruction.
void Update(HloInstruction* instr1, HloInstruction* instr2);