-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc      2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc   9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer.cc     5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc          25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h           21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc          6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc          57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h           27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc       7
9 files changed, 56 insertions, 103 deletions
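
The change replaces the per-instruction back-pointer parent_fusion_instruction_ on HloInstruction with a single fusion_instruction_ field owned by HloComputation; call sites that previously asked an instruction for its fusion instruction now go through the instruction's parent computation. A minimal sketch of the resulting shape (declarations simplified; the real classes live in hlo_instruction.h and hlo_computation.h):

    class HloInstruction;

    class HloComputation {
     public:
      // Non-null only for fusion computations.
      HloInstruction* FusionInstruction() const { return fusion_instruction_; }
      bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

     private:
      HloInstruction* fusion_instruction_ = nullptr;
    };

    class HloInstruction {
     public:
      HloComputation* parent() const { return parent_; }
      // An instruction is "fused" iff it lives inside a fusion computation.
      bool IsFused() const { return parent_->IsFusionComputation(); }

     private:
      HloComputation* parent_ = nullptr;
    };

    // Old spelling: instr->fusion_instruction()
    // New spelling: instr->parent()->FusionInstruction()
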
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 84bdd5acac..b02138325e 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -616,7 +616,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
auto random_value = [hlo]() {
const HloModule* module =
- hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent()
+ hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
: hlo->parent()->parent();
return module->RandomNew64();
};
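
For orientation, the new expression in this hunk walks four hops: a fused instruction's parent() is its fusion computation; that computation's FusionInstruction() is the enclosing kFusion op; that op's parent() is the computation containing it; one more parent() reaches the HloModule. A sketch with a hypothetical helper name (ModuleFor is not part of the patch):

    // Hypothetical helper; restates the lookup the hunk performs inline.
    const HloModule* ModuleFor(const HloInstruction* hlo) {
      if (hlo->IsFused()) {
        // fusion computation -> kFusion op -> enclosing computation -> module
        return hlo->parent()->FusionInstruction()->parent()->parent();
      }
      return hlo->parent()->parent();
    }
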
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index e7d32a4ae1..749badf3f2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -894,7 +894,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
&ir_builder_);
const HloInstruction* output =
- reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+ reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_,
"output_element_address");
@@ -1142,7 +1142,7 @@ Status IrEmitterUnnested::EmitRowReduction(
}
const HloInstruction* output =
- reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+ reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
// Emit an atomic operation that accumulates the partial reduction result of
// lane 0 (which holds the partially accumulated result for its warp) to the
@@ -1913,10 +1913,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
- // const HloInstruction* root = hlo.fused_expression_root();
- llvm_ir::EmitTuple(
- GetIrArray(*hlo.fused_expression_root()->fusion_instruction()),
- tuple_operand_ptrs, &ir_builder_);
+ llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_);
return Status::OK();
}
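
The EmitTuple simplification is sound because, in EmitTargetElementLoopInThunk, hlo is the kFusion instruction itself: the root of its fused computation maps straight back to hlo, so the removed expression was an identity. Stated as an invariant under the new API (a sketch, not code from the patch):

    // For the kFusion instruction `hlo` handled here:
    CHECK_EQ(&hlo, hlo.fused_expression_root()->parent()->FusionInstruction());
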
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index cecbb01ff8..ccdd171759 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -308,7 +308,7 @@ class WhileConditionComputationMatcher : public MatcherBase {
GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
CHECK(gte_fusion_param0->IsFused());
- if (gte_fusion_param0->fusion_instruction()->operand(
+ if (gte_fusion_param0->parent()->FusionInstruction()->operand(
gte_fusion_param0->parameter_number()) !=
computation_->parameter_instruction(0)) {
return InvalidArgument("Could not match fusion param: %s",
@@ -469,7 +469,8 @@ class WhileBodyComputationMatcher : public MatcherBase {
// Fusion parameter: lookup and compare with associated fusion operand.
CHECK_EQ(HloOpcode::kParameter, inst->opcode());
CHECK(inst->IsFused());
- if (inst->fusion_instruction()->operand(inst->parameter_number()) !=
+ if (inst->parent()->FusionInstruction()->operand(
+ inst->parameter_number()) !=
computation_->parameter_instruction(0)) {
return InvalidArgument("Could not match fusion param: %s",
inst->name().c_str());
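
Both matcher hunks rely on the same structural invariant: parameter i of a fusion computation is fed by operand i of its fusion instruction. A hedged sketch of the shared check (FusionOperandFor is a hypothetical name):

    // Hypothetical helper: the fusion operand that feeds a fused parameter.
    const HloInstruction* FusionOperandFor(const HloInstruction* param) {
      CHECK_EQ(HloOpcode::kParameter, param->opcode());
      CHECK(param->IsFused());
      const HloInstruction* fusion = param->parent()->FusionInstruction();
      return fusion->operand(param->parameter_number());
    }
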
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c030ceb72f..2d07784619 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -58,16 +58,16 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
CHECK_NE(nullptr, root);
return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, is_fusion_computation_));
+ root, fusion_instruction_));
}
HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation)
+ HloInstruction* root_instruction, HloInstruction* fusion_instruction)
: name_(name),
root_instruction_(root_instruction),
- is_fusion_computation_(is_fusion_computation) {
+ fusion_instruction_(fusion_instruction) {
param_instructions_.resize(parameter_count, nullptr);
bool root_found = false;
for (auto& instruction : *instructions) {
@@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal(
HloInstruction* HloComputation::AddParameter(
std::unique_ptr<HloInstruction> instruction) {
CHECK(instruction->opcode() == HloOpcode::kParameter);
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- CHECK(root_instruction_->fusion_instruction()->operand_count() ==
- param_instructions_.size());
+ CHECK(IsFusionComputation());
+ CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
instruction->set_parent(this);
param_instructions_.push_back(instruction.get());
AddInstructionInternal(std::move(instruction));
@@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter(
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
+ CHECK(IsFusionComputation());
HloInstruction* param_instruction = param_instructions_[param_no];
auto param_instruction_iterator = param_instructions_.begin() + param_no;
param_instructions_.erase(param_instruction_iterator);
@@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
- new_instr->SetParentFusion(root_instruction_->fusion_instruction());
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
param_no++;
@@ -166,10 +161,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
void HloComputation::Reparent(HloInstruction* instruction) {
instruction->set_parent(this);
- if (is_fusion_computation_ && instruction != root_instruction_) {
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- }
}
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
@@ -182,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) {
}
if (instruction->opcode() == HloOpcode::kParameter &&
- !is_fusion_computation_) {
+ !IsFusionComputation()) {
return false;
}
@@ -267,7 +258,7 @@ void HloComputation::set_root_instruction(
HloInstruction* new_root_instruction) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!is_fusion_computation_) {
+ if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape().ShortDebugString()
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index f383a17fb8..576c44a9f3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -56,10 +56,11 @@ class HloComputation {
// Builder class for HloComputation.
class Builder {
public:
- explicit Builder(const string& name, bool is_fusion_computation = false)
+ explicit Builder(const string& name,
+ HloInstruction* fusion_instruction = nullptr)
: name_(name),
last_added_instruction_(nullptr),
- is_fusion_computation_(is_fusion_computation) {}
+ fusion_instruction_(fusion_instruction) {}
// Build and return an HloComputation. The parameter root_instruction
// specifies the already-added instruction to use as the root. If
@@ -78,7 +79,7 @@ class HloComputation {
private:
const string name_;
HloInstruction* last_added_instruction_;
- bool is_fusion_computation_;
+ HloInstruction* fusion_instruction_;
std::vector<std::unique_ptr<HloInstruction>> instructions_;
};
@@ -274,13 +275,18 @@ class HloComputation {
bool HasSideEffect() const;
// Returns whether this computation is a fusion computation.
- bool IsFusionComputation() const { return is_fusion_computation_; }
+ bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }
+
+ // Returns the owning fusion instruction, or nullptr if this is not a fusion
+ // computation.
+ HloInstruction* FusionInstruction() const { return fusion_instruction_; }
private:
explicit HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation = false);
+ HloInstruction* root_instruction,
+ HloInstruction* fusion_instruction = nullptr);
// Internal helper for adding instructions.
HloInstruction* AddInstructionInternal(
@@ -309,8 +315,9 @@ class HloComputation {
string name_;
HloInstruction* root_instruction_;
- // A tag shows if this is a fusion computation.
- bool is_fusion_computation_;
+ // If this computation is a fusion computation, this field points to the
+ // corresponding fusion instruction. Otherwise, this is null.
+ HloInstruction* fusion_instruction_;
// Module containing this computation.
HloModule* parent_ = nullptr;
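
In use, the Builder now receives the owning fusion instruction instead of a bool, and Build() threads it into the computation. A minimal sketch, assuming fusion is the kFusion instruction being constructed and to_fuse is the instruction being cloned into it (both names illustrative):

    // Mirrors the CloneAndFuseInternal call sites below.
    auto builder = HloComputation::Builder("fused_computation", fusion);
    builder.AddInstruction(to_fuse->Clone(/*suffix=*/""));
    std::unique_ptr<HloComputation> computation = builder.Build();
    // computation->IsFusionComputation() is now true, and
    // computation->FusionInstruction() == fusion.
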
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 24a47f80af..dfb111d1d0 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -479,7 +479,7 @@ stylesheet="
// If this edge crosses a fusion cluster boundary, highlight it when the
// cluster is hovered over.
if (from_node->IsFused() &&
- from_node->fusion_instruction()->fused_expression_root() == from_node) {
+ from_node->parent()->root_instruction() == from_node) {
int64 cluster_id = cluster_ids_.at(from_node->parent());
add_hover_css_rule("clust", cluster_id, kBlue);
}
@@ -657,7 +657,7 @@ string HloDotDumper::GetInstructionNodeInlinedConstants(
// Special case: If instr is a parameter to a fusion node, check whether the
// corresponding operand to the fusion node is a constant.
if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->fusion_instruction();
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
const HloInstruction* operand = fusion->operand(instr->parameter_number());
if (operand->opcode() != HloOpcode::kConstant) {
return "";
@@ -898,7 +898,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
// expressions are handled specially -- we draw an edge from the corresponding
// operand on the fusion node itself to the parameter.
if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->fusion_instruction();
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
add_edge(fusion->operand(instr->parameter_number()), instr,
/*operand_num=*/0);
} else {
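
The first dumper hunk reduces to a simple predicate under the new model: a fused node is the fusion root exactly when it is the root instruction of its own parent computation. As a hypothetical helper:

    // Hypothetical: true iff n is the root of the fusion computation
    // that contains it (what the hovered-edge check tests).
    bool IsFusedExpressionRoot(const HloInstruction* n) {
      return n->IsFused() && n->parent()->root_instruction() == n;
    }
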
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 237e745383..3bdb67ba92 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -649,16 +649,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
if (called_computations_.empty()) {
// New fusion instruction. It should not be a multioutput instruction.
CHECK(!add_output);
- auto builder = HloComputation::Builder("fused_computation", true);
+ auto builder = HloComputation::Builder("fused_computation", this);
builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
called_computations_.push_back(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
- clone->parent_fusion_instruction_ = this;
} else {
clone = fused_instructions_computation()->AddInstruction(
instruction_to_fuse->Clone(/*suffix=*/""));
- clone->parent_fusion_instruction_ = this;
// When add_output is false, instruction_to_fuse is necessarily an operand
// of the fusion instruction. After fusion this will no longer be the case.
// Remove the operand from the operand list and remove its corresponding
@@ -727,12 +725,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// to avoid a double %%.
string param_name =
StrCat(operand->name().substr(1), ".param_", param_no);
- std::unique_ptr<HloInstruction> param_instruction =
- CreateParameter(param_no, operand->shape(), param_name);
-
- param_instruction->parent_fusion_instruction_ = this;
fused_param = fused_instructions_computation()->AddParameter(
- std::move(param_instruction));
+ CreateParameter(param_no, operand->shape(), param_name));
AppendOperand(operand);
}
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
@@ -762,7 +756,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction::CreateTuple(tuple_elements));
fused_instructions_computation()->set_root_instruction(new_root);
shape_ = new_root->shape();
- new_root->parent_fusion_instruction_ = this;
if (fused_root->opcode() == HloOpcode::kTuple) {
TF_CHECK_OK(
fused_instructions_computation()->RemoveInstruction(fused_root));
@@ -839,24 +832,17 @@ bool HloInstruction::HasSideEffect() const {
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
- fused_instructions_computation()->instructions();
- // All instructions owned by this fusion instruction must be fused, and the
- // parent fusion instruction of the fused instructions must be 'this'.
- for (auto& instruction : fused_instructions_) {
- CHECK(instruction->IsFused());
- CHECK_EQ(this, instruction->fusion_instruction());
- CHECK_EQ(fused_instructions_computation(), instruction->parent())
- << instruction->ToString();
- }
+ // The parent fusion instruction of the fusion computation must be 'this'.
+ HloComputation* fused_computation = fused_instructions_computation();
+ CHECK_EQ(this, fused_computation->FusionInstruction());
// Fused root instruction and fused parameters must all be owned by the fusion
- // instruction.
+ // computation.
bool root_owned = false;
const std::vector<HloInstruction*>& fused_parameters_ = fused_parameters();
const HloInstruction* fused_root_ = fused_expression_root();
std::vector<bool> parameter_owned(fused_parameters_.size(), false);
- for (auto& instruction : fused_instructions_) {
+ for (auto& instruction : fused_computation->instructions()) {
if (fused_root_ == instruction.get()) {
CHECK(!root_owned);
root_owned = true;
@@ -877,14 +863,13 @@ void HloInstruction::CheckFusionInstruction() const {
// Fused root must have no users.
CHECK_EQ(0, fused_root_->user_count());
- // All uses of fused instructions must be in the fusion instruction, and every
+ // All uses of fused instructions must be in the fusion computation, and every
// non-root instruction must have at least one use.
- for (auto& instruction : fused_instructions_) {
+ for (auto& instruction : fused_instructions_computation()->instructions()) {
if (instruction.get() != fused_root_) {
CHECK_GT(instruction->user_count(), 0);
for (auto& user : instruction->users()) {
- CHECK(user->IsFused());
- CHECK_EQ(this, user->fusion_instruction());
+ CHECK_EQ(fused_computation, user->parent());
}
}
}
@@ -1173,15 +1158,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
// Create the list of fused parameters by mapping through the cloned,
// fused instructions.
- std::vector<HloInstruction*> new_fused_parameters;
- const std::vector<HloInstruction*>& fused_parameters_ =
- fused_instructions_computation()->parameter_instructions();
-
- for (HloInstruction* old_fused_parameter : fused_parameters_) {
+ for (HloInstruction* old_fused_parameter :
+ fused_instructions_computation()->parameter_instructions()) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
- new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get();
- new_fused_parameters.push_back(new_fusion_parameter);
InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
}
for (auto old_fused_instruction :
@@ -1202,12 +1182,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
old_fused_instruction->shape(), new_operands));
HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
new_fused_instruction->set_parent(parent());
- new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
}
new_instruction->fusion_kind_ = fusion_kind_;
auto computation_builder = HloComputation::Builder(
- fused_instructions_computation()->name() + ".clone", true);
+ fused_instructions_computation()->name() + ".clone",
+ new_instruction.get());
// We iterated the fusion instructions in reverse post order which means
// that we must reverse our new list of fusion instructions.
for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
@@ -1912,9 +1892,7 @@ string HloInstruction::TracingTag() const {
return literal_->u8s_string();
}
-bool HloInstruction::IsFused() const {
- return parent_fusion_instruction_ != nullptr;
-}
+bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
bool HloInstruction::IsFusable() const {
// Instructions which are traced should not be fused.
@@ -1949,11 +1927,6 @@ HloComputation* HloInstruction::fused_instructions_computation() const {
return fused_instructions_computation;
}
-HloInstruction* HloInstruction::fusion_instruction() const {
- CHECK(IsFused());
- return parent_fusion_instruction_;
-}
-
HloInstruction* HloInstruction::fused_expression_root() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
return fused_instructions_computation()->root_instruction();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 923aeb47f0..5688fcc425 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -603,26 +603,21 @@ class HloInstruction {
// instruction.
bool IsFused() const;
+ // Returns the computation for this fused instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloComputation* fused_instructions_computation() const;
+
// Returns true if this instruction can be legally fused into a fusion
// instruction.
bool IsFusable() const;
- // Returns the fusion instruction that contains this instruction.
- //
- // Note: only valid if this instruction is fused into a fusion instruction.
- HloInstruction* fusion_instruction() const;
-
// Returns the root instruction of the fused expression contained within this
// fusion instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
HloInstruction* fused_expression_root() const;
- // Returns the computation for this fused instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloComputation* fused_instructions_computation() const;
-
// Returns the list of fused instructions inside this fusion instruction.
//
// Note: although the list itself is const, the instructions contained in the
@@ -898,14 +893,6 @@ class HloInstruction {
// instruction to make it a bitcast.
bool CouldBeBitcast() const;
- // Sets the parent fusion instruction for this instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- void SetParentFusion(HloInstruction* fusion_instruction) {
- CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
- parent_fusion_instruction_ = fusion_instruction;
- }
-
// CHECKs various invariants of a fusion instruction.
void CheckFusionInstruction() const;
@@ -1049,10 +1036,6 @@ class HloInstruction {
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
- // If this instruction is fused into a fusion instruction, this field points
- // to the fusion instruction.
- HloInstruction* parent_fusion_instruction_ = nullptr;
-
// The type of the fusion. Used by kFusion only.
FusionKind fusion_kind_;
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 76177462aa..5a4c93b59a 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -91,10 +91,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
string node_name;
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
- if (instruction->IsFused()) {
- node_name = GetNodeNameForInstruction(instruction->fusion_instruction());
+ const HloComputation* computation = instruction->parent();
+ if (computation->IsFusionComputation()) {
+ node_name = GetNodeNameForInstruction(computation->FusionInstruction());
} else {
- node_name = instruction->parent()->name();
+ node_name = computation->name();
if (!instruction->metadata().op_name().empty()) {
// Always make computations contain TF ops but not the other way around.
StrAppend(&node_name, "/", instruction->metadata().op_name());