diff options
author | 2017-04-17 13:39:36 -0800 | |
---|---|---|
committer | 2017-04-17 14:59:42 -0700 | |
commit | 33fd4134234170745f989e2cdd73c8ca8709d926 (patch) | |
tree | c7fad80e48ba8cb8da246583669266c143a75543 /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | cc45456e4ad0eff16127d1727d0cf48afb71ca0e (diff) |
[XLA] Represent fusion instructions as a HloComputation
Use a HloComputation to represent the HloInstructions inside a fusion
instruction.
All the interfaces are kept the same except for the parent field of the fusion
instruction. It now points to the newly created HloComputation rather than the
enclosing computation for the fusion instruction.
Change: 153390245
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 73 |
1 files changed, 52 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 15de24fffd..655546d715 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -52,16 +52,17 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build( root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique( - new HloComputation(name_, parameter_count, &instructions_, root)); + return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, + root, is_fusion_computation_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector<std::unique_ptr<HloInstruction>>* instructions, - HloInstruction* root_instruction) + HloInstruction* root_instruction, bool is_fusion_computation) : name_(name), root_instruction_(root_instruction), + is_fusion_computation_(is_fusion_computation), instruction_name_uniquer_(/*separator=*/".") { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; @@ -99,19 +100,54 @@ HloInstruction* HloComputation::AddInstructionInternal( return pinst; } -void HloComputation::Reparent(HloInstruction* instruction) { +HloInstruction* HloComputation::AddParameter( + std::unique_ptr<HloInstruction> instruction) { + CHECK(instruction->opcode() == HloOpcode::kParameter); + CHECK(is_fusion_computation_); + CHECK(root_instruction_->fusion_instruction() != nullptr); + instruction->SetParentFusion(root_instruction_->fusion_instruction()); + CHECK(root_instruction_->fusion_instruction()->operand_count() == + param_instructions_.size()); instruction->set_parent(this); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& i : instruction->fused_instructions()) { - Reparent(i.get()); - } + param_instructions_.push_back(instruction.get()); + AddInstructionInternal(std::move(instruction)); + return instructions_.back().get(); +} + +Status 
HloComputation::RemoveParameter(int64 param_no) { + CHECK_GE(param_no, 0); + CHECK_LT(param_no, param_instructions_.size()); + CHECK(is_fusion_computation_); + CHECK(root_instruction_->fusion_instruction() != nullptr); + HloInstruction* param_instruction = param_instructions_[param_no]; + auto param_instruction_iterator = param_instructions_.begin() + param_no; + param_instructions_.erase(param_instruction_iterator); + // Throw removed fused parameter instruction away. + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + + while (param_no < param_instructions_.size()) { + param_instruction = param_instructions_[param_no]; + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_instruction->name())); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + new_instr->SetParentFusion(root_instruction_->fusion_instruction()); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + param_no++; } + + return Status::OK(); } -/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) { - return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv || - opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace || - opcode == HloOpcode::kOutfeed); +void HloComputation::Reparent(HloInstruction* instruction) { + instruction->set_parent(this); +} + +bool HloComputation::IsRemovable(const HloOpcode& opcode) { + return !((opcode == HloOpcode::kParameter && !is_fusion_computation_) || + opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend || + opcode == HloOpcode::kTrace || opcode == HloOpcode::kOutfeed); } Status HloComputation::RemoveInstructionAndUnusedOperands( @@ -119,7 +155,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(root_instruction() != instruction); TF_RET_CHECK(instruction->user_count() == 0); - TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode())); 
+ TF_RET_CHECK(IsRemovable(instruction->opcode())); std::unordered_set<HloInstruction*> removed; std::queue<HloInstruction*> worklist; worklist.push(instruction); @@ -128,8 +164,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || - !HloComputation::IsRemovable(item->opcode())) { + item == root_instruction() || !IsRemovable(item->opcode())) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -302,12 +337,8 @@ string HloComputation::ToString() const { for (const HloInstruction* instruction : MakeInstructionPostOrder()) { s << " " << instruction->ToString() << "\n"; if (instruction->opcode() == HloOpcode::kFusion) { - tensorflow::gtl::FlatSet<HloInstruction*> added_instructions; - auto fused_instructions = InstructionPostOrderer::GetOrder( - instruction->fused_expression_root(), &added_instructions); - for (const auto& fused_instruction : fused_instructions) { - s << " " << fused_instruction->ToString() << "\n"; - } + s << " " << instruction->fused_instructions_computation()->ToString() + << "\n"; } } s << "}"; |