aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-01-05 10:55:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 10:59:03 -0800
commit604272261e37d0bad102d4e4cbf9e90c66489222 (patch)
tree90bf3fa17cf2a5451cc84916411d09ad4966f7c0 /tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
parent003ae0eab1d9a95a85e07887cdecc3e8c836b6e3 (diff)
[TF:XLA] Correctly simplify while loops with a non-tuple root body
Explicitly bail out if the root instruction of a loop body isn't a tuple() instruction. PiperOrigin-RevId: 180948724
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index d99b31dc00..c5183f8d3a 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -418,5 +418,32 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) {
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
}
+TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) {
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+ HloComputation* while_body = [&]() {
+ HloComputation::Builder builder(TestName() + ".passthrough");
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, while_shape, "param"));
+ HloComputation* result = module().AddEmbeddedComputation(builder.Build());
+
+ result->AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ return result;
+ }();
+
+ HloComputation::Builder builder(TestName());
+ auto* init_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+ while_body, init_value));
+ module().AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+ WhileLoopSimplifier{}.Run(&module()));
+ EXPECT_FALSE(simplified_loop);
+}
+
} // namespace
} // namespace xla