diff options
author | 2018-03-09 10:33:28 -0800 | |
---|---|---|
committer | 2018-03-09 10:38:50 -0800 | |
commit | 87dab2d8289750c9d34f26d7d5fb18475dff985b (patch) | |
tree | 17d1d0de205110553f015946fdd45a488ad325d7 /tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | |
parent | 58d5fa05a67b65979708f541336c2c11bfed978e (diff) |
Automated g4 rollback of changelist 188397087
PiperOrigin-RevId: 188503184
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 96 |
1 files changed, 1 insertions, 95 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 396f942dc0..cbea3e3cf2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -30,11 +30,6 @@ class WhileLoopSimplifierTest : public HloVerifiedTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. void MakeModuleWithSimpleLoop(int num_iters); - - // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to - // the loop-condition through an element of a tuple which is the - // loop-condition parameter. - void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { @@ -71,45 +66,6 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { ParseAndVerifyModule(hlo_string.c_str()); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( - int num_iters) { - string hlo_string_template = R"( - HloModule SimpleLoopWithIndirectLoopBound - SimpleLoopWithIndirectLoopBound.body { - loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.1 = s32[] constant(1) - add = s32[] add(get-tuple-element.1, constant.1) - get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 - multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) - limit = s32[] get-tuple-element(loop_var.1), index=2 - ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit) - } - SimpleLoopWithIndirectLoopBound.condition { - loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 - ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) - } - ENTRY SimpleLoopWithIndirectLoopBound { - constant.3 = s32[] constant(42) - constant.4 = s32[3]{0} constant({0, 1, 2}) - constant.2 = s32[] constant({{LOOP_BOUND}}) - tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4, - constant.2) - ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1), - condition=SimpleLoopWithIndirectLoopBound.condition, - body=SimpleLoopWithIndirectLoopBound.body - } - )"; - - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); - ParseAndVerifyModule(hlo_string.c_str()); -} - TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { MakeModuleWithSimpleLoop(/*num_iters=*/0); HloModule* the_module = &module(); @@ -118,15 +74,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { op::Tuple(op::Constant(), op::Constant())); } -TEST_F(WhileLoopSimplifierTest, - LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), - op::Tuple(op::Constant(), op::Constant(), op::Constant())); -} - TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { MakeModuleWithSimpleLoop(/*num_iters=*/1); HloModule* the_module = &module(); @@ -135,15 +82,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { op::Tuple(op::Add(), op::Multiply())); } -TEST_F(WhileLoopSimplifierTest, - LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), - op::Tuple(op::Add(), op::Multiply(), op::Constant())); -} - TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { MakeModuleWithSimpleLoop(/*num_iters=*/2); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); @@ -426,6 +364,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { HloModule BodyHasNonTupleRoot BodyHasNonTupleRoot.passthrough { ROOT param = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1 } BodyHasNonTupleRoot.always_true { param.1 = (s32[], s32[]) parameter(0) @@ -443,38 +382,5 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, - LoopWithNonTupleBodyRootInstructionNotSimplified) { - const string hlo_string = R"( - HloModule SimpleLoop - SimpleLoop.body { - loop_var.1 = (s32[], s32[3]{0}) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.1 = s32[] constant(1) - add = s32[] add(get-tuple-element.1, constant.1) - get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 - multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) - ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply), - custom_call_target="x" - } - SimpleLoop.condition { - loop_var.2 = (s32[], s32[3]{0}) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - constant.2 = s32[] constant(44) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) - } - ENTRY SimpleLoop { - constant.3 = s32[] constant(42) - constant.4 = s32[3]{0} constant({0, 1, 2}) - tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) - ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= - SimpleLoop.condition, body=SimpleLoop.body - } - )"; - - ParseAndVerifyModule(hlo_string.c_str()); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); -} - } // namespace } // namespace xla |