diff options
9 files changed, 56 insertions, 103 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 84bdd5acac..b02138325e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -616,7 +616,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( auto random_value = [hlo]() { const HloModule* module = - hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent() + hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() : hlo->parent()->parent(); return module->RandomNew64(); }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index e7d32a4ae1..749badf3f2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -894,7 +894,7 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); const HloInstruction* output = - reduce->IsFused() ? reduce->fusion_instruction() : reduce; + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, "output_element_address"); @@ -1142,7 +1142,7 @@ Status IrEmitterUnnested::EmitRowReduction( } const HloInstruction* output = - reduce->IsFused() ? reduce->fusion_instruction() : reduce; + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; // Emit an atomic operation that accumulates the partial reduction result of // lane 0 (which holds the partially accumulated result for its warp) to the @@ -1913,10 +1913,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - // const HloInstruction* root = hlo.fused_expression_root(); - llvm_ir::EmitTuple( - GetIrArray(*hlo.fused_expression_root()->fusion_instruction()), - tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index cecbb01ff8..ccdd171759 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -308,7 +308,7 @@ class WhileConditionComputationMatcher : public MatcherBase { GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->fusion_instruction()->operand( + if (gte_fusion_param0->parent()->FusionInstruction()->operand( gte_fusion_param0->parameter_number()) != computation_->parameter_instruction(0)) { return InvalidArgument("Could not match fusion param: %s", @@ -469,7 +469,8 @@ class WhileBodyComputationMatcher : public MatcherBase { // Fusion parameter: lookup and compare with associated fusion operand. CHECK_EQ(HloOpcode::kParameter, inst->opcode()); CHECK(inst->IsFused()); - if (inst->fusion_instruction()->operand(inst->parameter_number()) != + if (inst->parent()->FusionInstruction()->operand( + inst->parameter_number()) != computation_->parameter_instruction(0)) { return InvalidArgument("Could not match fusion param: %s", inst->name().c_str()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c030ceb72f..2d07784619 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -58,16 +58,16 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build( CHECK_NE(nullptr, root); return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, is_fusion_computation_)); + root, fusion_instruction_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector<std::unique_ptr<HloInstruction>>* instructions, - HloInstruction* root_instruction, bool is_fusion_computation) + HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), root_instruction_(root_instruction), - is_fusion_computation_(is_fusion_computation) { + fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { @@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal( HloInstruction* HloComputation::AddParameter( std::unique_ptr<HloInstruction> instruction) { CHECK(instruction->opcode() == HloOpcode::kParameter); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); - instruction->SetParentFusion(root_instruction_->fusion_instruction()); - CHECK(root_instruction_->fusion_instruction()->operand_count() == - param_instructions_.size()); + CHECK(IsFusionComputation()); + CHECK(fusion_instruction_->operand_count() == param_instructions_.size()); instruction->set_parent(this); param_instructions_.push_back(instruction.get()); AddInstructionInternal(std::move(instruction)); @@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter( Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); + CHECK(IsFusionComputation()); HloInstruction* param_instruction = param_instructions_[param_no]; auto param_instruction_iterator = param_instructions_.begin() + param_no; param_instructions_.erase(param_instruction_iterator); @@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); - new_instr->SetParentFusion(root_instruction_->fusion_instruction()); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); param_no++; @@ -166,10 +161,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { void HloComputation::Reparent(HloInstruction* instruction) { instruction->set_parent(this); - if (is_fusion_computation_ && instruction != root_instruction_) { - CHECK(root_instruction_->fusion_instruction() != nullptr); - instruction->SetParentFusion(root_instruction_->fusion_instruction()); - } } bool HloComputation::IsRemovable(const HloInstruction* instruction) { @@ -182,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { } if (instruction->opcode() == HloOpcode::kParameter && - !is_fusion_computation_) { + !IsFusionComputation()) { return false; } @@ -267,7 +258,7 @@ void HloComputation::set_root_instruction( HloInstruction* new_root_instruction) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!is_fusion_computation_) { + if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape().ShortDebugString() diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index f383a17fb8..576c44a9f3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -56,10 +56,11 @@ class HloComputation { // Builder class for HloComputation. class Builder { public: - explicit Builder(const string& name, bool is_fusion_computation = false) + explicit Builder(const string& name, + HloInstruction* fusion_instruction = nullptr) : name_(name), last_added_instruction_(nullptr), - is_fusion_computation_(is_fusion_computation) {} + fusion_instruction_(fusion_instruction) {} // Build and return an HloComputation. The parameter root_instruction // specifies the already-added instruction to use as the root. If @@ -78,7 +79,7 @@ class HloComputation { private: const string name_; HloInstruction* last_added_instruction_; - bool is_fusion_computation_; + HloInstruction* fusion_instruction_; std::vector<std::unique_ptr<HloInstruction>> instructions_; }; @@ -274,13 +275,18 @@ class HloComputation { bool HasSideEffect() const; // Returns if this computation is a fusion computation. - bool IsFusionComputation() const { return is_fusion_computation_; } + bool IsFusionComputation() const { return fusion_instruction_ != nullptr; } + + // Returns the owning fusion instruction, or nullptr if this is not a fusion + // computation. + HloInstruction* FusionInstruction() const { return fusion_instruction_; } private: explicit HloComputation( const string& name, int parameter_count, std::vector<std::unique_ptr<HloInstruction>>* instructions, - HloInstruction* root_instruction, bool is_fusion_computation = false); + HloInstruction* root_instruction, + HloInstruction* fusion_instruction = nullptr); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -309,8 +315,9 @@ class HloComputation { string name_; HloInstruction* root_instruction_; - // A tag shows if this is a fusion computation. - bool is_fusion_computation_; + // If this computation is a fusion computation, this field points to the + // corresponding fusion instruction. Otherwise, this is null. + HloInstruction* fusion_instruction_; // Module containing this computation. HloModule* parent_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 24a47f80af..dfb111d1d0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -479,7 +479,7 @@ stylesheet=" // If this edge crosses a fusion cluster boundary, highlight it when the // cluster is hovered over. if (from_node->IsFused() && - from_node->fusion_instruction()->fused_expression_root() == from_node) { + from_node->parent()->root_instruction() == from_node) { int64 cluster_id = cluster_ids_.at(from_node->parent()); add_hover_css_rule("clust", cluster_id, kBlue); } @@ -657,7 +657,7 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( // Special case: If instr is a parameter to a fusion node, check whether the // corresponding operand to the fusion node is a constant. if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* fusion = instr->parent()->FusionInstruction(); const HloInstruction* operand = fusion->operand(instr->parameter_number()); if (operand->opcode() != HloOpcode::kConstant) { return ""; @@ -898,7 +898,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // expressions are handled specially -- we draw an edge from the corresponding // operand on the fusion node itself to the parameter. if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* fusion = instr->parent()->FusionInstruction(); add_edge(fusion->operand(instr->parameter_number()), instr, /*operand_num=*/0); } else { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 237e745383..3bdb67ba92 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -649,16 +649,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( if (called_computations_.empty()) { // New fusion instruction. It should not be a multioutput instruction. CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", true); + auto builder = HloComputation::Builder("fused_computation", this); builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); called_computations_.push_back( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); - clone->parent_fusion_instruction_ = this; } else { clone = fused_instructions_computation()->AddInstruction( instruction_to_fuse->Clone(/*suffix=*/"")); - clone->parent_fusion_instruction_ = this; // When add_output is false, instruction_to_fuse is necessarily an operand // of the fusion instruction. After fusion this will no longer be the case. // Remove the operand from the operand list and remove its corresponding @@ -727,12 +725,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // to avoid a double %%. string param_name = StrCat(operand->name().substr(1), ".param_", param_no); - std::unique_ptr<HloInstruction> param_instruction = - CreateParameter(param_no, operand->shape(), param_name); - - param_instruction->parent_fusion_instruction_ = this; fused_param = fused_instructions_computation()->AddParameter( - std::move(param_instruction)); + CreateParameter(param_no, operand->shape(), param_name)); AppendOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); @@ -762,7 +756,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction::CreateTuple(tuple_elements)); fused_instructions_computation()->set_root_instruction(new_root); shape_ = new_root->shape(); - new_root->parent_fusion_instruction_ = this; if (fused_root->opcode() == HloOpcode::kTuple) { TF_CHECK_OK( fused_instructions_computation()->RemoveInstruction(fused_root)); @@ -839,24 +832,17 @@ bool HloInstruction::HasSideEffect() const { void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ = - fused_instructions_computation()->instructions(); - // All instructions owned by this fusion instruction must be fused, and the - // parent fusion instruction of the fused instructions must be 'this'. - for (auto& instruction : fused_instructions_) { - CHECK(instruction->IsFused()); - CHECK_EQ(this, instruction->fusion_instruction()); - CHECK_EQ(fused_instructions_computation(), instruction->parent()) - << instruction->ToString(); - } + // The parent fusion instruction of the fusion computation must be 'this'. + HloComputation* fused_computation = fused_instructions_computation(); + CHECK_EQ(this, fused_computation->FusionInstruction()); // Fused root instruction and fused parameters must all be owned by the fusion - // instruction. + // computation. bool root_owned = false; const std::vector<HloInstruction*>& fused_parameters_ = fused_parameters(); const HloInstruction* fused_root_ = fused_expression_root(); std::vector<bool> parameter_owned(fused_parameters_.size(), false); - for (auto& instruction : fused_instructions_) { + for (auto& instruction : fused_computation->instructions()) { if (fused_root_ == instruction.get()) { CHECK(!root_owned); root_owned = true; @@ -877,14 +863,13 @@ void HloInstruction::CheckFusionInstruction() const { // Fused root must have no users. CHECK_EQ(0, fused_root_->user_count()); - // All uses of fused instructions must be in the fusion instruction, and every + // All uses of fused instructions must be in the fusion computation, and every // non-root instruction must have at least one use. - for (auto& instruction : fused_instructions_) { + for (auto& instruction : fused_instructions_computation()->instructions()) { if (instruction.get() != fused_root_) { CHECK_GT(instruction->user_count(), 0); for (auto& user : instruction->users()) { - CHECK(user->IsFused()); - CHECK_EQ(this, user->fusion_instruction()); + CHECK_EQ(fused_computation, user->parent()); } } } @@ -1173,15 +1158,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( std::list<std::unique_ptr<HloInstruction>> new_fused_instructions; // Create the list of fused parameters by mapping through the cloned, // fused instructions. - std::vector<HloInstruction*> new_fused_parameters; - const std::vector<HloInstruction*>& fused_parameters_ = - fused_instructions_computation()->parameter_instructions(); - - for (HloInstruction* old_fused_parameter : fused_parameters_) { + for (HloInstruction* old_fused_parameter : + fused_instructions_computation()->parameter_instructions()) { new_fused_instructions.push_back(old_fused_parameter->Clone()); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get(); - new_fused_parameters.push_back(new_fusion_parameter); InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); } for (auto old_fused_instruction : @@ -1202,12 +1182,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( old_fused_instruction->shape(), new_operands)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); new_fused_instruction->set_parent(parent()); - new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } new_instruction->fusion_kind_ = fusion_kind_; auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", true); + fused_instructions_computation()->name() + ".clone", + new_instruction.get()); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); @@ -1912,9 +1892,7 @@ string HloInstruction::TracingTag() const { return literal_->u8s_string(); } -bool HloInstruction::IsFused() const { - return parent_fusion_instruction_ != nullptr; -} +bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { // Instructions which are traced should not be fused. @@ -1949,11 +1927,6 @@ HloComputation* HloInstruction::fused_instructions_computation() const { return fused_instructions_computation; } -HloInstruction* HloInstruction::fusion_instruction() const { - CHECK(IsFused()); - return parent_fusion_instruction_; -} - HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); return fused_instructions_computation()->root_instruction(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 923aeb47f0..5688fcc425 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -603,26 +603,21 @@ class HloInstruction { // instruction. bool IsFused() const; + // Returns the computation for this fused instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + HloComputation* fused_instructions_computation() const; + // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the fusion instruction that contains this instruction. - // - // Note: only valid if this instruction is fused into a fusion instruction. - HloInstruction* fusion_instruction() const; - // Returns the root instruction of the fused expression contained within this // fusion instruction. // // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns the list of fused instructions inside this fusioninstruction. // // Note: although the list itself is const, the instructions contained in the @@ -898,14 +893,6 @@ class HloInstruction { // instruction to make it a bitcast. bool CouldBeBitcast() const; - // Sets the parent fusion instruction for this instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - void SetParentFusion(HloInstruction* fusion_instruction) { - CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); - parent_fusion_instruction_ = fusion_instruction; - } - // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; @@ -1049,10 +1036,6 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr<PaddingConfig> padding_config_; - // If this instruction is fused into a fusion instruction, this field points - // to the fusion instruction. - HloInstruction* parent_fusion_instruction_ = nullptr; - // The type of the fusion. Used by kFusion only. FusionKind fusion_kind_; diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 76177462aa..5a4c93b59a 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -91,10 +91,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( string node_name; // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. - if (instruction->IsFused()) { - node_name = GetNodeNameForInstruction(instruction->fusion_instruction()); + const HloComputation* computation = instruction->parent(); + if (computation->IsFusionComputation()) { + node_name = GetNodeNameForInstruction(computation->FusionInstruction()); } else { - node_name = instruction->parent()->name(); + node_name = computation->name(); if (!instruction->metadata().op_name().empty()) { // Always make computations contain TF ops but not the other way around. StrAppend(&node_name, "/", instruction->metadata().op_name()); |