aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-11-02 19:12:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 19:18:02 -0700
commit58143d36c06c2b027ae7f9f4d51dadcdc1c66b74 (patch)
treed206d37c80b4587346ede5a0e00c88afee9a227e /tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
parent02608eadc34e5a606a95375ba078879145a55b7e (diff)
[XLA] Add dead tuple elem removal to WhileLoopSimplifier.
Specifically, if a while loop has tuple element that - is not used by the while condition, and - is not used by the while body, except to pass it along to the next iteration of the loop, then we can reshape the while loop's computations to eliminate this tuple element. PiperOrigin-RevId: 174413683
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc296
1 files changed, 272 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 609a5b3885..8e1a2dcde1 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -28,11 +28,16 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
public:
// Makes a computation that contains a loop that runs num_iters times.
- HloComputation* MakeSimpleLoop(HloModule* module, int num_iters);
+ HloComputation* MakeSimpleLoop(int num_iters, HloModule* module);
+
+ // Makes a computation which has one parameter, of the given shape, and always
+ // returns PRED[]{true}. This is useful as a dummy loop condition.
+ HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
+ HloModule* module);
};
-HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(HloModule* module,
- int num_iters) {
+HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters,
+ HloModule* module) {
HloComputation::Builder builder(TestName());
auto loop_iter_init = builder.AddInstruction(
@@ -89,38 +94,44 @@ HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(HloModule* module,
return module->AddEntryComputation(builder.Build());
}
+HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation(
+ const Shape& param_shape, HloModule* module) {
+ HloComputation::Builder builder(TestName() + ".always_true");
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, param_shape, "param"));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ return module->AddEmbeddedComputation(builder.Build());
+}
+
TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/0);
- ASSERT_TRUE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module());
+ ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Tuple(op::Constant(), op::Constant()));
}
TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1);
- ASSERT_TRUE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
+ ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Tuple(op::Add(), op::Multiply()));
}
TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) {
- HloModule module(TestName());
- MakeSimpleLoop(&module, /*num_iters=*/2);
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ MakeSimpleLoop(/*num_iters=*/2, &module());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1);
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* true_op = while_op->while_body()->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
TF_ASSERT_OK(true_op->AddControlDependencyTo(
while_op->while_body()->root_instruction()));
- ASSERT_TRUE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction()->control_predecessors(),
ElementsAre(op::Constant()))
<< computation->ToString();
@@ -129,8 +140,7 @@ TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) {
// Loops that contain send/recv nodes can't be simplified; the loop structure
// around send/recv nodes must be preserved.
TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1);
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
@@ -138,19 +148,18 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
while_body->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
/*channel_id=*/0));
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1);
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
while_body->AddInstruction(
HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
/*channel_id=*/0));
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
// The limitation on not being able to simplify loops that contain infeeds (and
@@ -158,14 +167,253 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
// fact that our infrastructure sees simplifying such a loop as tantamount to
// removing the non-removable instruction.
TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
- HloModule module(TestName());
- HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1);
+ HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module());
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
while_body->AddInstruction(
HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module).ValueOrDie());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// Check that we don't crash when given a loop whose shape is not a tuple.
+TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ HloComputation* condition;
+ {
+ HloComputation::Builder cond_builder(TestName() + ".condition");
+ auto param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param,
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(100)))));
+ condition = module().AddEmbeddedComputation(cond_builder.Build());
+ }
+
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ auto param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param,
+ body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(-1)))));
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+
+ module().AddEntryComputation(builder.Build());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// Construct a loop where we swap the tuple elements in each iteration.
+// Although the tuple elements aren't used in the loop, we don't eliminate them,
+// because the swapping side-effect is visible to users of the loop.
+TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
+ }));
+
+ HloComputation* condition =
+ MakeAlwaysTrueComputation(loop_init->shape(), &module());
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ auto param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ body_builder.AddInstruction(HloInstruction::CreateTuple({
+ body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)),
+ body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)),
+ }));
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+
+ module().AddEntryComputation(builder.Build());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// Construct a loop where we assign a constant to tuple element 0 in each
+// iteration. We can't eliminate tuple element 0, even though we never use its
+// value.
+TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(
+ HloInstruction::CreateTuple({builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)))}));
+
+ HloComputation* condition =
+ MakeAlwaysTrueComputation(loop_init->shape(), &module());
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
+ body_builder.AddInstruction(HloInstruction::CreateTuple({
+ body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
+ }));
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+
+ module().AddEntryComputation(builder.Build());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// Nothing to simplify in a while loop whose tuple has 0 elements.
+TEST_F(WhileLoopSimplifierTest, EmptyTuple) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({}));
+
+ HloComputation* condition =
+ MakeAlwaysTrueComputation(loop_init->shape(), &module());
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var"));
+ body_builder.AddInstruction(HloInstruction::CreateTuple({}));
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+ module().AddEntryComputation(builder.Build());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// While loop where one tuple element is used twice in the body, and thus can't
+// be simplified away.
+TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))),
+ }));
+
+ HloComputation* condition =
+ MakeAlwaysTrueComputation(loop_init->shape(), &module());
+
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ auto* param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_init->shape(), "param0"));
+ auto* gte0 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0));
+ // get0 is used twice in the loop body's tuple.
+ body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0}));
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+ module().AddEntryComputation(builder.Build());
+ EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
+// This while loop has three tuple elements. Element 0 is unused and should be
+// removed. Element 1 is used by the loop body, and element 2 is used by the
+// loop condition; these two should stay.
+TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) {
+ HloComputation::Builder builder(TestName());
+ auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ }));
+ auto loop_shape = loop_init->shape();
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+
+ HloComputation* condition;
+ {
+ HloComputation::Builder cond_builder(TestName() + ".loop_condition");
+ auto param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_shape, "param0"));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq,
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))),
+ cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ scalar_s32, param, /*index=*/2))));
+ condition = module().AddEmbeddedComputation(cond_builder.Build());
+ }
+
+ HloComputation* body;
+ {
+ HloComputation::Builder body_builder(TestName() + ".body");
+ auto* param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_shape, "loop_var"));
+
+ auto* tuple0 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0));
+ auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_s32, HloOpcode::kAdd,
+ body_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ scalar_s32, param, /*index=*/1)),
+ body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)))));
+ auto* tuple2 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2));
+ body_builder.AddInstruction(
+ HloInstruction::CreateTuple({tuple0, tuple1, tuple2}));
+
+ body = module().AddEmbeddedComputation(body_builder.Build());
+ }
+
+ auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_init->shape(), condition, body, loop_init));
+
+ module().AddEntryComputation(builder.Build());
+ EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+
+ // We leave most of the checking to HloVerifiedTestBase, which runs the
+ // verifier on module() at the end of this test.
+ HloInstruction* new_while_op = *std::find_if(
+ module().entry_computation()->instructions().begin(),
+ module().entry_computation()->instructions().end(),
+ [&](const HloInstruction* instr) {
+ return instr != while_op && instr->opcode() == HloOpcode::kWhile;
+ });
+ EXPECT_TRUE(
+ ShapeUtil::Equal(new_while_op->shape(),
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32})))
+ << ShapeUtil::HumanString(new_while_op->shape());
+ EXPECT_THAT(
+ new_while_op->while_body()->root_instruction(),
+ op::Tuple(
+ op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0),
+ op::Constant()),
+ op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
+
+ EXPECT_THAT(new_while_op->while_condition()->root_instruction(),
+ op::Eq(op::Constant(),
+ op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
}
} // namespace