about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com>, 2018-07-09 02:47:53 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-07-09 02:51:20 -0700
commit: 0063183a62f69c2523a3982c70d72e231428fb60 (patch)
tree: 941a7ecb62b0235db2d7d917040af6d850781799 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent: 955e356e4c69d3fce4ac2bac5966671e964f9627 (diff)
Fix crash when running with --v=2.
When doing multi-output fusion with sibling fusion, it can happen that we don't need to clone the 'instruction_to_fuse' argument. Currently we create the clone anyway, delete it again, and then, at the end of the function, try to print the debug string of the already-deleted clone, which crashes. Instead, we simply do not create the clone when it is not needed, and we check for this case before printing the debug string.

PiperOrigin-RevId: 203733796
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 22
1 file changed, 16 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ed934c689a..c160647f7a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1053,8 +1053,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
} else {
- clone = fused_instructions_computation()->AddInstruction(
- instruction_to_fuse->Clone(/*suffix=*/""));
// When add_output is false, instruction_to_fuse is necessarily an operand
// of the fusion instruction. After fusion this will no longer be the
// case. Remove the operand from the operand list and remove its
@@ -1064,6 +1062,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
bool in_operand_list = std::find(operands().begin(), operands().end(),
instruction_to_fuse) != operands().end();
CHECK(add_output || in_operand_list);
+ if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
+ // We assume all uses of a kTuple operation are GTE ops, not another
+ // fusion node. In this case, we don't need to clone
+ // 'instruction_to_fuse'.
+ CHECK(!in_operand_list);
+ clone = instruction_to_fuse;
+ } else {
+ clone = fused_instructions_computation()->AddInstruction(
+ instruction_to_fuse->Clone(/*suffix=*/""));
+ }
const std::vector<HloInstruction*>& fused_parameters =
fused_instructions_computation()->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -1160,9 +1168,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
int64 index = tuple_elements.size();
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
- index -= instruction_to_fuse->operand_count();
+ CHECK_EQ(clone, instruction_to_fuse);
+ index -= clone->operand_count();
std::vector<HloInstruction*> to_be_removed;
- for (auto old_gte : instruction_to_fuse->users()) {
+ for (auto old_gte : clone->users()) {
CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
int64 old_tuple_index = old_gte->tuple_index();
HloInstruction* new_gte =
@@ -1174,7 +1183,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
for (auto old_gte : to_be_removed) {
TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
}
- TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
} else {
HloInstruction* new_gte =
parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -1183,7 +1191,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
}
- VLOG(2) << "New clone:\n" << clone->ToString();
+ if (clone != instruction_to_fuse) {
+ VLOG(2) << "New clone:\n" << clone->ToString();
+ }
return clone;
}