diff options
author | 2017-08-27 08:30:51 -0700 | |
---|---|---|
committer | 2017-08-27 08:34:32 -0700 | |
commit | f938347aeba95afa55d5dd0d3f911689eff31821 (patch) | |
tree | 7b71fe2a83fe20bd6317bbae48220113c786d9cf /tensorflow/compiler/xla/service/instruction_fusion.cc | |
parent | feb4e648bf72ebc6e6dc377e95329a93821e5eba (diff) |
Minor bugfix to HloInstruction::MergeFusionInstruction, and allow InstructionFusion::Fuse to be overridden by derived classes.
PiperOrigin-RevId: 166631382
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.cc | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index edfcb0922d..d449c637b5 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -212,7 +212,7 @@ bool InstructionFusion::CanFuseOnAllPaths( StatusOr<bool> InstructionFusion::Run(HloModule* module) { bool changed = false; - + module_ = module; std::vector<HloComputation*> computations; for (auto& computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -395,7 +395,6 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); - auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -407,8 +406,8 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } - fusion_instruction->FuseInstruction(producer); + fusion_instruction->FuseInstruction(producer); return fusion_instruction; } @@ -423,13 +422,15 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && - consumer->fusion_kind() != HloInstruction::FusionKind::kInput) { + consumer->fusion_kind() != HloInstruction::FusionKind::kInput && + consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) { return false; } - // Cost condition: not fuse (expensive producers) and (consumers who reuse - // operand elements). - if (consumer->ReusesOperandElements(operand_index) && + // Cost condition: not fuse (simple, expensive producers) and (consumers who + // reuse operand elements). + if (producer->opcode() != HloOpcode::kFusion && + consumer->ReusesOperandElements(operand_index) && is_expensive_(*producer)) { return false; } |