diff options
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 27 |
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index fb0e6f7ce0..1073cc7700 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -306,6 +306,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // Bail if param0 of while_cond or while_body has users which aren't of type diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d99b31dc00..c5183f8d3a 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -418,5 +418,32 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } +TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopSimplifier{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla |