aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 10:33:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 10:38:50 -0800
commit87dab2d8289750c9d34f26d7d5fb18475dff985b (patch)
tree17d1d0de205110553f015946fdd45a488ad325d7 /tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
parent58d5fa05a67b65979708f541336c2c11bfed978e (diff)
Automated g4 rollback of changelist 188397087
PiperOrigin-RevId: 188503184
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc96
1 files changed, 1 insertions, 95 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 396f942dc0..cbea3e3cf2 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -30,11 +30,6 @@ class WhileLoopSimplifierTest : public HloVerifiedTestBase {
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
-
- // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
- // the loop-condition through an element of a tuple which is the
- // loop-condition parameter.
- void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
};
void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
@@ -71,45 +66,6 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
ParseAndVerifyModule(hlo_string.c_str());
}
-void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
- int num_iters) {
- string hlo_string_template = R"(
- HloModule SimpleLoopWithIndirectLoopBound
- SimpleLoopWithIndirectLoopBound.body {
- loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
- constant.1 = s32[] constant(1)
- add = s32[] add(get-tuple-element.1, constant.1)
- get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
- multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
- limit = s32[] get-tuple-element(loop_var.1), index=2
- ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
- }
- SimpleLoopWithIndirectLoopBound.condition {
- loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
- get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
- get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
- ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
- }
- ENTRY SimpleLoopWithIndirectLoopBound {
- constant.3 = s32[] constant(42)
- constant.4 = s32[3]{0} constant({0, 1, 2})
- constant.2 = s32[] constant({{LOOP_BOUND}})
- tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
- constant.2)
- ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
- condition=SimpleLoopWithIndirectLoopBound.condition,
- body=SimpleLoopWithIndirectLoopBound.body
- }
- )";
-
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
- ParseAndVerifyModule(hlo_string.c_str());
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/0);
HloModule* the_module = &module();
@@ -118,15 +74,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
op::Tuple(op::Constant(), op::Constant()));
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithZeroIterationTupleElementLoopBoundSimplified) {
- MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
- HloModule* the_module = &module();
- ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
- EXPECT_THAT(the_module->entry_computation()->root_instruction(),
- op::Tuple(op::Constant(), op::Constant(), op::Constant()));
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/1);
HloModule* the_module = &module();
@@ -135,15 +82,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
op::Tuple(op::Add(), op::Multiply()));
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithOneIterationTupleELementLoopBoundSimplified) {
- MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
- HloModule* the_module = &module();
- ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
- EXPECT_THAT(the_module->entry_computation()->root_instruction(),
- op::Tuple(op::Add(), op::Multiply(), op::Constant()));
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/2);
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
@@ -426,6 +364,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
HloModule BodyHasNonTupleRoot
BodyHasNonTupleRoot.passthrough {
ROOT param = (s32[], s32[]) parameter(0)
+ get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1
}
BodyHasNonTupleRoot.always_true {
param.1 = (s32[], s32[]) parameter(0)
@@ -443,38 +382,5 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithNonTupleBodyRootInstructionNotSimplified) {
- const string hlo_string = R"(
- HloModule SimpleLoop
- SimpleLoop.body {
- loop_var.1 = (s32[], s32[3]{0}) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
- constant.1 = s32[] constant(1)
- add = s32[] add(get-tuple-element.1, constant.1)
- get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
- multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
- ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
- custom_call_target="x"
- }
- SimpleLoop.condition {
- loop_var.2 = (s32[], s32[3]{0}) parameter(0)
- get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
- constant.2 = s32[] constant(44)
- ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
- }
- ENTRY SimpleLoop {
- constant.3 = s32[] constant(42)
- constant.4 = s32[3]{0} constant({0, 1, 2})
- tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
- ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
- SimpleLoop.condition, body=SimpleLoop.body
- }
- )";
-
- ParseAndVerifyModule(hlo_string.c_str());
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
-}
-
} // namespace
} // namespace xla