diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ccc2c38749..1100fd1edc 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -92,13 +92,22 @@ HloInstruction* HloComputation::AddInstructionInternal( // Generate a unique name for the instruction. instruction->set_name( instruction_name_uniquer_.GetUniqueName(instruction->name())); - instruction->set_parent(this); + Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); return pinst; } +void HloComputation::Reparent(HloInstruction* instruction) { + instruction->set_parent(this); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& instruction : instruction->fused_instructions()) { + Reparent(instruction.get()); + } + } +} + /* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) { return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace || |