aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ccc2c38749..1100fd1edc 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -92,13 +92,22 @@ HloInstruction* HloComputation::AddInstructionInternal(
// Generate a unique name for the instruction.
instruction->set_name(
instruction_name_uniquer_.GetUniqueName(instruction->name()));
- instruction->set_parent(this);
+ Reparent(instruction.get());
HloInstruction* pinst = instruction.get();
instruction_iterators_[pinst] =
instructions_.insert(instructions_.end(), std::move(instruction));
return pinst;
}
+void HloComputation::Reparent(HloInstruction* instruction) {
+ instruction->set_parent(this);
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ for (auto& instruction : instruction->fused_instructions()) {
+ Reparent(instruction.get());
+ }
+ }
+}
+
/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||