diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index ed934c689a..c160647f7a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1053,8 +1053,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); // When add_output is false, instruction_to_fuse is necessarily an operand // of the fusion instruction. After fusion this will no longer be the // case. Remove the operand from the operand list and remove its @@ -1064,6 +1062,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( bool in_operand_list = std::find(operands().begin(), operands().end(), instruction_to_fuse) != operands().end(); CHECK(add_output || in_operand_list); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + // We assume all uses of a kTuple operation are GTE ops, not another + // fusion node. In this case, we don't need to clone + // 'instruction_to_fuse'. + CHECK(!in_operand_list); + clone = instruction_to_fuse; + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + } const std::vector<HloInstruction*>& fused_parameters = fused_instructions_computation()->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { @@ -1160,9 +1168,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); + CHECK_EQ(clone, instruction_to_fuse); + index -= clone->operand_count(); std::vector<HloInstruction*> to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { + for (auto old_gte : clone->users()) { CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); int64 old_tuple_index = old_gte->tuple_index(); HloInstruction* new_gte = @@ -1174,7 +1183,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( for (auto old_gte : to_be_removed) { TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); } else { HloInstruction* new_gte = parent()->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -1183,7 +1191,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } } - VLOG(2) << "New clone:\n" << clone->ToString(); + if (clone != instruction_to_fuse) { + VLOG(2) << "New clone:\n" << clone->ToString(); + } return clone; } |