about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/multi_output_fusion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/multi_output_fusion.cc')
-rw-r--r-- tensorflow/compiler/xla/service/multi_output_fusion.cc 25
1 files changed, 2 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 79b5a442aa..4166ef5baf 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
HloInstruction* fused = instr2;
// Make sure that if only one of the instructions is a fusion, or if only one
// of the instructions is a multi-output fusion, it's what will be fused into.
- //
- // An invariant is that no bitcast nodes will show up in the middle of a
- // fusion node. This invariant must hold in order for us to lower it. Given
- // that, we require that during multi-output fusion, a fusion node ending with
- // bitcast to preserve its structure as a nested fusion instead being
- // merged and flattened.
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
std::swap(remaining, fused);
}
if (fused->IsMultiOutputFusion()) {
std::swap(remaining, fused);
}
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
remaining->MergeFusionInstructionIntoMultiOutput(fused);
} else {
- if (remaining->opcode() == HloOpcode::kFusion &&
- remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) {
- auto parent_computation = remaining->parent();
- // Create a nested fusion node.
- auto remaining_nested_fused =
- parent_computation->AddInstruction(HloInstruction::CreateFusion(
- remaining->shape(), HloInstruction::FusionKind::kLoop,
- remaining));
- TF_CHECK_OK(parent_computation->ReplaceInstruction(
- remaining, remaining_nested_fused));
- remaining = remaining_nested_fused;
- }
remaining->FuseInstructionIntoMultiOutput(fused);
}
-
return remaining;
}