author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-08-30 13:19:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-08-30 13:23:54 -0700
commit    e565d1f1fced69789feb10f1ea1241157ec95f93 (patch)
tree      0351aafe956f71c30f6aa27da7917bc10a0afdf0
parent    79603e4dd92b182ca0689da97d4d992d1d2b1316 (diff)
[XLA] Refactor parent-fusion-instruction pointer into HloComputation, not HloInstruction.
Presently, each instruction inside a fusion computation contains a pointer to the fusion instruction that contains the computation. This pointer is redundant, since it is the same for every instruction in the computation. It leads to many places where the pointer must be set when adding an instruction to the fusion computation (and to bugs such as b/65177535 when one is missed), as well as to code that checks it is set correctly. It is also simply unnecessary data bloat.

Moreover, the computation itself does not contain a pointer to the fusion instruction that references it, which leads to odd circumlocutions in the HloComputation code that retrieve the fusion instruction from the computation's root instruction.

Thus, this CL moves the pointer into the HloComputation class (replacing the is_fusion_computation_ bool) and refactors the uses as necessary.

PiperOrigin-RevId: 167039280
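For orientation, a minimal compilable sketch of the pointer move described above, with the HLO classes reduced to stand-ins that model only the members this CL touches (the stand-in constructors are invented for illustration; the real classes carry far more state):

    #include <cassert>
    #include <cstdio>

    class HloInstruction;

    class HloComputation {
     public:
      // After this CL: the computation, not each fused instruction, records
      // which fusion instruction (if any) owns it.
      explicit HloComputation(HloInstruction* fusion_instruction = nullptr)
          : fusion_instruction_(fusion_instruction) {}

      bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }
      HloInstruction* FusionInstruction() const { return fusion_instruction_; }

     private:
      // Replaces the old bool is_fusion_computation_; null for non-fusion
      // computations.
      HloInstruction* fusion_instruction_;
    };

    class HloInstruction {
     public:
      HloComputation* parent() const { return parent_; }
      void set_parent(HloComputation* computation) { parent_ = computation; }

      // Before this CL every fused instruction carried its own
      // parent_fusion_instruction_ pointer; fused-ness is now derived from
      // the parent computation instead.
      bool IsFused() const {
        return parent_ != nullptr && parent_->IsFusionComputation();
      }

     private:
      HloComputation* parent_ = nullptr;
    };

    int main() {
      HloInstruction fusion;               // stands in for a kFusion instruction
      HloComputation fused_comp(&fusion);  // the computation the fusion calls
      HloInstruction inner;                // an instruction fused into it
      inner.set_parent(&fused_comp);

      // Call sites change from inner.fusion_instruction() to:
      assert(inner.IsFused());
      assert(inner.parent()->FusionInstruction() == &fusion);
      std::puts("fusion instruction recovered via the parent computation");
      return 0;
    }

Accordingly, call sites that previously read instr->fusion_instruction() become instr->parent()->FusionInstruction() throughout the hunks below.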
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc    |  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc |  9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer.cc   |  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc         | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h          | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc        |  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc         | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h          | 27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc     |  7
9 files changed, 56 insertions, 103 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 84bdd5acac..b02138325e 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -616,7 +616,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
auto random_value = [hlo]() {
const HloModule* module =
- hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent()
+ hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
: hlo->parent()->parent();
return module->RandomNew64();
};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index e7d32a4ae1..749badf3f2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -894,7 +894,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
&ir_builder_);
const HloInstruction* output =
- reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+ reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_,
"output_element_address");
@@ -1142,7 +1142,7 @@ Status IrEmitterUnnested::EmitRowReduction(
}
const HloInstruction* output =
- reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+ reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
// Emit an atomic operation that accumulates the partial reduction result of
// lane 0 (which holds the partially accumulated result for its warp) to the
@@ -1913,10 +1913,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
- // const HloInstruction* root = hlo.fused_expression_root();
- llvm_ir::EmitTuple(
- GetIrArray(*hlo.fused_expression_root()->fusion_instruction()),
- tuple_operand_ptrs, &ir_builder_);
+ llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index cecbb01ff8..ccdd171759 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -308,7 +308,7 @@ class WhileConditionComputationMatcher : public MatcherBase {
GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
CHECK(gte_fusion_param0->IsFused());
- if (gte_fusion_param0->fusion_instruction()->operand(
+ if (gte_fusion_param0->parent()->FusionInstruction()->operand(
gte_fusion_param0->parameter_number()) !=
computation_->parameter_instruction(0)) {
return InvalidArgument("Could not match fusion param: %s",
@@ -469,7 +469,8 @@ class WhileBodyComputationMatcher : public MatcherBase {
// Fusion parameter: lookup and compare with associated fusion operand.
CHECK_EQ(HloOpcode::kParameter, inst->opcode());
CHECK(inst->IsFused());
- if (inst->fusion_instruction()->operand(inst->parameter_number()) !=
+ if (inst->parent()->FusionInstruction()->operand(
+ inst->parameter_number()) !=
computation_->parameter_instruction(0)) {
return InvalidArgument("Could not match fusion param: %s",
inst->name().c_str());
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c030ceb72f..2d07784619 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -58,16 +58,16 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
CHECK_NE(nullptr, root);
return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, is_fusion_computation_));
+ root, fusion_instruction_));
}
HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation)
+ HloInstruction* root_instruction, HloInstruction* fusion_instruction)
: name_(name),
root_instruction_(root_instruction),
- is_fusion_computation_(is_fusion_computation) {
+ fusion_instruction_(fusion_instruction) {
param_instructions_.resize(parameter_count, nullptr);
bool root_found = false;
for (auto& instruction : *instructions) {
@@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal(
HloInstruction* HloComputation::AddParameter(
std::unique_ptr<HloInstruction> instruction) {
CHECK(instruction->opcode() == HloOpcode::kParameter);
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- CHECK(root_instruction_->fusion_instruction()->operand_count() ==
- param_instructions_.size());
+ CHECK(IsFusionComputation());
+ CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
instruction->set_parent(this);
param_instructions_.push_back(instruction.get());
AddInstructionInternal(std::move(instruction));
@@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter(
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
+ CHECK(IsFusionComputation());
HloInstruction* param_instruction = param_instructions_[param_no];
auto param_instruction_iterator = param_instructions_.begin() + param_no;
param_instructions_.erase(param_instruction_iterator);
@@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
- new_instr->SetParentFusion(root_instruction_->fusion_instruction());
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
param_no++;
@@ -166,10 +161,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
void HloComputation::Reparent(HloInstruction* instruction) {
instruction->set_parent(this);
- if (is_fusion_computation_ && instruction != root_instruction_) {
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- }
}
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
@@ -182,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) {
}
if (instruction->opcode() == HloOpcode::kParameter &&
- !is_fusion_computation_) {
+ !IsFusionComputation()) {
return false;
}
@@ -267,7 +258,7 @@ void HloComputation::set_root_instruction(
HloInstruction* new_root_instruction) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!is_fusion_computation_) {
+ if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape().ShortDebugString()
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index f383a17fb8..576c44a9f3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -56,10 +56,11 @@ class HloComputation {
// Builder class for HloComputation.
class Builder {
public:
- explicit Builder(const string& name, bool is_fusion_computation = false)
+ explicit Builder(const string& name,
+ HloInstruction* fusion_instruction = nullptr)
: name_(name),
last_added_instruction_(nullptr),
- is_fusion_computation_(is_fusion_computation) {}
+ fusion_instruction_(fusion_instruction) {}
// Build and return an HloComputation. The parameter root_instruction
// specifies the already-added instruction to use as the root. If
@@ -78,7 +79,7 @@ class HloComputation {
private:
const string name_;
HloInstruction* last_added_instruction_;
- bool is_fusion_computation_;
+ HloInstruction* fusion_instruction_;
std::vector<std::unique_ptr<HloInstruction>> instructions_;
};
@@ -274,13 +275,18 @@ class HloComputation {
bool HasSideEffect() const;
// Returns if this computation is a fusion computation.
- bool IsFusionComputation() const { return is_fusion_computation_; }
+ bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }
+
+ // Returns the owning fusion instruction, or nullptr if this is not a fusion
+ // computation.
+ HloInstruction* FusionInstruction() const { return fusion_instruction_; }
private:
explicit HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation = false);
+ HloInstruction* root_instruction,
+ HloInstruction* fusion_instruction = nullptr);
// Internal helper for adding instructions.
HloInstruction* AddInstructionInternal(
@@ -309,8 +315,9 @@ class HloComputation {
string name_;
HloInstruction* root_instruction_;
- // A tag shows if this is a fusion computation.
- bool is_fusion_computation_;
+ // If this computation is a fusion computation, this field points to the
+ // corresponding fusion instruction. Otherwise, this is null.
+ HloInstruction* fusion_instruction_;
// Module containing this computation.
HloModule* parent_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 24a47f80af..dfb111d1d0 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -479,7 +479,7 @@ stylesheet="
// If this edge crosses a fusion cluster boundary, highlight it when the
// cluster is hovered over.
if (from_node->IsFused() &&
- from_node->fusion_instruction()->fused_expression_root() == from_node) {
+ from_node->parent()->root_instruction() == from_node) {
int64 cluster_id = cluster_ids_.at(from_node->parent());
add_hover_css_rule("clust", cluster_id, kBlue);
}
@@ -657,7 +657,7 @@ string HloDotDumper::GetInstructionNodeInlinedConstants(
// Special case: If instr is a parameter to a fusion node, check whether the
// corresponding operand to the fusion node is a constant.
if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->fusion_instruction();
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
const HloInstruction* operand = fusion->operand(instr->parameter_number());
if (operand->opcode() != HloOpcode::kConstant) {
return "";
@@ -898,7 +898,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
// expressions are handled specially -- we draw an edge from the corresponding
// operand on the fusion node itself to the parameter.
if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->fusion_instruction();
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
add_edge(fusion->operand(instr->parameter_number()), instr,
/*operand_num=*/0);
} else {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 237e745383..3bdb67ba92 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -649,16 +649,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
if (called_computations_.empty()) {
// New fusion instruction. It should not be a multioutput instruction.
CHECK(!add_output);
- auto builder = HloComputation::Builder("fused_computation", true);
+ auto builder = HloComputation::Builder("fused_computation", this);
builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
called_computations_.push_back(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
- clone->parent_fusion_instruction_ = this;
} else {
clone = fused_instructions_computation()->AddInstruction(
instruction_to_fuse->Clone(/*suffix=*/""));
- clone->parent_fusion_instruction_ = this;
// When add_output is false, instruction_to_fuse is necessarily an operand
// of the fusion instruction. After fusion this will no longer be the case.
// Remove the operand from the operand list and remove its corresponding
@@ -727,12 +725,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// to avoid a double %%.
string param_name =
StrCat(operand->name().substr(1), ".param_", param_no);
- std::unique_ptr<HloInstruction> param_instruction =
- CreateParameter(param_no, operand->shape(), param_name);
-
- param_instruction->parent_fusion_instruction_ = this;
fused_param = fused_instructions_computation()->AddParameter(
- std::move(param_instruction));
+ CreateParameter(param_no, operand->shape(), param_name));
AppendOperand(operand);
}
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
@@ -762,7 +756,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction::CreateTuple(tuple_elements));
fused_instructions_computation()->set_root_instruction(new_root);
shape_ = new_root->shape();
- new_root->parent_fusion_instruction_ = this;
if (fused_root->opcode() == HloOpcode::kTuple) {
TF_CHECK_OK(
fused_instructions_computation()->RemoveInstruction(fused_root));
@@ -839,24 +832,17 @@ bool HloInstruction::HasSideEffect() const {
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
- fused_instructions_computation()->instructions();
- // All instructions owned by this fusion instruction must be fused, and the
- // parent fusion instruction of the fused instructions must be 'this'.
- for (auto& instruction : fused_instructions_) {
- CHECK(instruction->IsFused());
- CHECK_EQ(this, instruction->fusion_instruction());
- CHECK_EQ(fused_instructions_computation(), instruction->parent())
- << instruction->ToString();
- }
+ // The parent fusion instruction of the fusion computation must be 'this'.
+ HloComputation* fused_computation = fused_instructions_computation();
+ CHECK_EQ(this, fused_computation->FusionInstruction());
// Fused root instruction and fused parameters must all be owned by the fusion
- // instruction.
+ // computation.
bool root_owned = false;
const std::vector<HloInstruction*>& fused_parameters_ = fused_parameters();
const HloInstruction* fused_root_ = fused_expression_root();
std::vector<bool> parameter_owned(fused_parameters_.size(), false);
- for (auto& instruction : fused_instructions_) {
+ for (auto& instruction : fused_computation->instructions()) {
if (fused_root_ == instruction.get()) {
CHECK(!root_owned);
root_owned = true;
@@ -877,14 +863,13 @@ void HloInstruction::CheckFusionInstruction() const {
// Fused root must have no users.
CHECK_EQ(0, fused_root_->user_count());
- // All uses of fused instructions must be in the fusion instruction, and every
+ // All uses of fused instructions must be in the fusion computation, and every
// non-root instruction must have at least one use.
- for (auto& instruction : fused_instructions_) {
+ for (auto& instruction : fused_instructions_computation()->instructions()) {
if (instruction.get() != fused_root_) {
CHECK_GT(instruction->user_count(), 0);
for (auto& user : instruction->users()) {
- CHECK(user->IsFused());
- CHECK_EQ(this, user->fusion_instruction());
+ CHECK_EQ(fused_computation, user->parent());
}
}
}
@@ -1173,15 +1158,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
// Create the list of fused parameters by mapping through the cloned,
// fused instructions.
- std::vector<HloInstruction*> new_fused_parameters;
- const std::vector<HloInstruction*>& fused_parameters_ =
- fused_instructions_computation()->parameter_instructions();
-
- for (HloInstruction* old_fused_parameter : fused_parameters_) {
+ for (HloInstruction* old_fused_parameter :
+ fused_instructions_computation()->parameter_instructions()) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
- new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get();
- new_fused_parameters.push_back(new_fusion_parameter);
InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
}
for (auto old_fused_instruction :
@@ -1202,12 +1182,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
old_fused_instruction->shape(), new_operands));
HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
new_fused_instruction->set_parent(parent());
- new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
}
new_instruction->fusion_kind_ = fusion_kind_;
auto computation_builder = HloComputation::Builder(
- fused_instructions_computation()->name() + ".clone", true);
+ fused_instructions_computation()->name() + ".clone",
+ new_instruction.get());
// We iterated the fusion instructions in reverse post order which means
// that we must reverse our new list of fusion instructions.
for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
@@ -1912,9 +1892,7 @@ string HloInstruction::TracingTag() const {
return literal_->u8s_string();
}
-bool HloInstruction::IsFused() const {
- return parent_fusion_instruction_ != nullptr;
-}
+bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
bool HloInstruction::IsFusable() const {
// Instructions which are traced should not be fused.
@@ -1949,11 +1927,6 @@ HloComputation* HloInstruction::fused_instructions_computation() const {
return fused_instructions_computation;
}
-HloInstruction* HloInstruction::fusion_instruction() const {
- CHECK(IsFused());
- return parent_fusion_instruction_;
-}
-
HloInstruction* HloInstruction::fused_expression_root() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
return fused_instructions_computation()->root_instruction();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 923aeb47f0..5688fcc425 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -603,26 +603,21 @@ class HloInstruction {
// instruction.
bool IsFused() const;
+ // Returns the computation for this fused instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloComputation* fused_instructions_computation() const;
+
// Returns true if this instruction can be legally fused into a fusion
// instruction.
bool IsFusable() const;
- // Returns the fusion instruction that contains this instruction.
- //
- // Note: only valid if this instruction is fused into a fusion instruction.
- HloInstruction* fusion_instruction() const;
-
// Returns the root instruction of the fused expression contained within this
// fusion instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
HloInstruction* fused_expression_root() const;
- // Returns the computation for this fused instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloComputation* fused_instructions_computation() const;
-
// Returns the list of fused instructions inside this fusion instruction.
//
// Note: although the list itself is const, the instructions contained in the
@@ -898,14 +893,6 @@ class HloInstruction {
// instruction to make it a bitcast.
bool CouldBeBitcast() const;
- // Sets the parent fusion instruction for this instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- void SetParentFusion(HloInstruction* fusion_instruction) {
- CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
- parent_fusion_instruction_ = fusion_instruction;
- }
-
// CHECKs various invariants of a fusion instruction.
void CheckFusionInstruction() const;
@@ -1049,10 +1036,6 @@ class HloInstruction {
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
- // If this instruction is fused into a fusion instruction, this field points
- // to the fusion instruction.
- HloInstruction* parent_fusion_instruction_ = nullptr;
-
// The type of the fusion. Used by kFusion only.
FusionKind fusion_kind_;
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 76177462aa..5a4c93b59a 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -91,10 +91,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
string node_name;
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
- if (instruction->IsFused()) {
- node_name = GetNodeNameForInstruction(instruction->fusion_instruction());
+ const HloComputation* computation = instruction->parent();
+ if (computation->IsFusionComputation()) {
+ node_name = GetNodeNameForInstruction(computation->FusionInstruction());
} else {
- node_name = instruction->parent()->name();
+ node_name = computation->name();
if (!instruction->metadata().op_name().empty()) {
// Always make computations contain TF ops but not the other way around.
StrAppend(&node_name, "/", instruction->metadata().op_name());