diff options
author | 2017-08-27 08:30:51 -0700 | |
---|---|---|
committer | 2017-08-27 08:34:32 -0700 | |
commit | f938347aeba95afa55d5dd0d3f911689eff31821 (patch) | |
tree | 7b71fe2a83fe20bd6317bbae48220113c786d9cf | |
parent | feb4e648bf72ebc6e6dc377e95329a93821e5eba (diff) |
Minor bugfix to HloInstruction::MergeFusionInstruction, and allow InstructionFusion::Fuse to be overridden by derived classes.
PiperOrigin-RevId: 166631382
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 9 |
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.cc | 15 |
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.h | 7 |
3 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 28ca915310..f5a081a9dc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -532,6 +532,8 @@ void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); // Clone the instruction from which to merge fused instructions. std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone(); // Replace uses of fused parameters with the corresponding operand of the @@ -563,6 +565,11 @@ void HloInstruction::MergeFusionInstruction( } CHECK_EQ(0, clone->user_count()); clone->DetachFromOperands(); + + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation( + clone->fused_instructions_computation())); + } } void HloInstruction::MergeFusionInstructionIntoMultiOutput( @@ -2131,6 +2138,7 @@ using DFSStack = // cycle was detected, and true otherwise. 
inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, HloInstruction* child) { + CHECK(child != nullptr); const int id = child->unique_id(); CHECK_GE(id, 0) << "instruction may not have a parent computation"; switch (visitor->GetVisitState(id)) { @@ -2193,7 +2201,6 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting); const size_t old_dfs_stack_size = dfs_stack.size(); - for (HloInstruction* child : current_node->operands()) { if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index edfcb0922d..d449c637b5 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -212,7 +212,7 @@ bool InstructionFusion::CanFuseOnAllPaths( StatusOr<bool> InstructionFusion::Run(HloModule* module) { bool changed = false; - + module_ = module; std::vector<HloComputation*> computations; for (auto& computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -395,7 +395,6 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); - auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -407,8 +406,8 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } - fusion_instruction->FuseInstruction(producer); + fusion_instruction->FuseInstruction(producer); return fusion_instruction; } @@ -423,13 +422,15 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (consumer->opcode() == HloOpcode::kFusion && 
consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && - consumer->fusion_kind() != HloInstruction::FusionKind::kInput) { + consumer->fusion_kind() != HloInstruction::FusionKind::kInput && + consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) { return false; } - // Cost condition: not fuse (expensive producers) and (consumers who reuse - // operand elements). - if (consumer->ReusesOperandElements(operand_index) && + // Cost condition: not fuse (simple, expensive producers) and (consumers who + // reuse operand elements). + if (producer->opcode() != HloOpcode::kFusion && + consumer->ReusesOperandElements(operand_index) && is_expensive_(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f6f37bb79b..3ac13ffda0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -66,12 +66,15 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); + // Fuses producer into consumer. + virtual HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer); + // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; + HloModule* module_; private: - HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); - // The set of producers whose consumers we cannot fuse into. using DoNotFuseSet = std::unordered_set<HloInstruction*>; |