diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 25 |
1 files changed, 8 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c030ceb72f..2d07784619 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -58,16 +58,16 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build( CHECK_NE(nullptr, root); return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, is_fusion_computation_)); + root, fusion_instruction_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector<std::unique_ptr<HloInstruction>>* instructions, - HloInstruction* root_instruction, bool is_fusion_computation) + HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), root_instruction_(root_instruction), - is_fusion_computation_(is_fusion_computation) { + fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { @@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal( HloInstruction* HloComputation::AddParameter( std::unique_ptr<HloInstruction> instruction) { CHECK(instruction->opcode() == HloOpcode::kParameter); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); - instruction->SetParentFusion(root_instruction_->fusion_instruction()); - CHECK(root_instruction_->fusion_instruction()->operand_count() == - param_instructions_.size()); + CHECK(IsFusionComputation()); + CHECK(fusion_instruction_->operand_count() == param_instructions_.size()); instruction->set_parent(this); param_instructions_.push_back(instruction.get()); AddInstructionInternal(std::move(instruction)); @@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter( Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); + CHECK(IsFusionComputation()); HloInstruction* param_instruction = param_instructions_[param_no]; auto param_instruction_iterator = param_instructions_.begin() + param_no; param_instructions_.erase(param_instruction_iterator); @@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); - new_instr->SetParentFusion(root_instruction_->fusion_instruction()); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); param_no++; @@ -166,10 +161,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { void HloComputation::Reparent(HloInstruction* instruction) { instruction->set_parent(this); - if (is_fusion_computation_ && instruction != root_instruction_) { - CHECK(root_instruction_->fusion_instruction() != nullptr); - instruction->SetParentFusion(root_instruction_->fusion_instruction()); - } } bool HloComputation::IsRemovable(const HloInstruction* instruction) { @@ -182,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { } if (instruction->opcode() == HloOpcode::kParameter && - !is_fusion_computation_) { + !IsFusionComputation()) { return false; } @@ -267,7 +258,7 @@ void HloComputation::set_root_instruction( HloInstruction* new_root_instruction) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!is_fusion_computation_) { + if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape().ShortDebugString() |