diff options
author | 2018-03-09 10:33:28 -0800 | |
---|---|---|
committer | 2018-03-09 10:38:50 -0800 | |
commit | 87dab2d8289750c9d34f26d7d5fb18475dff985b (patch) | |
tree | 17d1d0de205110553f015946fdd45a488ad325d7 /tensorflow/compiler/xla | |
parent | 58d5fa05a67b65979708f541336c2c11bfed978e (diff) |
Automated g4 rollback of changelist 188397087
PiperOrigin-RevId: 188503184
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier.cc | 76 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 96 |
2 files changed, 2 insertions, 170 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 1a93a880dd..c9d77c9376 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -606,75 +605,6 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) { return false; } -static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) { - auto while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - return false; - } - - auto while_body = while_op->while_body(); - auto while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - return false; - } - - auto while_body_param = while_body->parameter_instruction(0); - const HloInstruction::InstructionVector& root_operands = - while_body_root->operands(); - - // Find the loop invariant tuple elements with constant init value and - // build a map from the tuple element index to the constant value. - tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant; - for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && instr->operand(0) == while_body_param) { - auto tuple_element = while_init->operand(i); - if (tuple_element->IsConstant()) { - VLOG(3) << "Found loop invariant tuple element " << i << " " - << tuple_element->ToString(); - index_to_constant[i] = tuple_element; - } - } - } - - if (index_to_constant.empty()) { - return false; - } - - // Replace the use of each constant tuple element in the loop_condition and - // loop_body with the corresponding constant value. - auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> { - HloInstruction* param = computation->parameter_instruction(0); - bool changed = false; - for (auto instr : param->users()) { - // Since only a while-loop with a tuple result reaches here, we can safely - // assume that `param` is a tuple and the first operand of the - // GetTupleElement instruction is a use of `param`. - if (instr->opcode() == HloOpcode::kGetTupleElement) { - VLOG(3) << "tuple index " << instr->tuple_index() << " " - << instr->ToString(); - auto iter = index_to_constant.find(instr->tuple_index()); - if (iter != index_to_constant.end()) { - const HloInstruction* hlo_constant = (*iter).second; - VLOG(3) << "Replace use of " << instr->ToString() << " with " - << hlo_constant->ToString(); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( - computation->AddInstruction(hlo_constant->Clone()))); - changed = true; - } - } - } - return changed; - }; - - TF_ASSIGN_OR_RETURN(bool changed_cond, - propagate_constant(while_op->while_condition())); - TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); - - return changed_cond || changed_body; -} - StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -705,11 +635,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) { continue; } - StatusOr<bool> result = TryPropagateConstant(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); - - result = TryRemoveWhileLoop(while_op); + StatusOr<bool> result = TryRemoveWhileLoop(while_op); TF_RETURN_IF_ERROR(result.status()); if (result.ValueOrDie()) { changed = true; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 396f942dc0..cbea3e3cf2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -30,11 +30,6 @@ class WhileLoopSimplifierTest : public HloVerifiedTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. void MakeModuleWithSimpleLoop(int num_iters); - - // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to - // the loop-condition through an element of a tuple which is the - // loop-condition parameter. - void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { @@ -71,45 +66,6 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { ParseAndVerifyModule(hlo_string.c_str()); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( - int num_iters) { - string hlo_string_template = R"( - HloModule SimpleLoopWithIndirectLoopBound - SimpleLoopWithIndirectLoopBound.body { - loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.1 = s32[] constant(1) - add = s32[] add(get-tuple-element.1, constant.1) - get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 - multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) - limit = s32[] get-tuple-element(loop_var.1), index=2 - ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit) - } - SimpleLoopWithIndirectLoopBound.condition { - loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 - ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) - } - ENTRY SimpleLoopWithIndirectLoopBound { - constant.3 = s32[] constant(42) - constant.4 = s32[3]{0} constant({0, 1, 2}) - constant.2 = s32[] constant({{LOOP_BOUND}}) - tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4, - constant.2) - ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1), - condition=SimpleLoopWithIndirectLoopBound.condition, - body=SimpleLoopWithIndirectLoopBound.body - } - )"; - - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); - ParseAndVerifyModule(hlo_string.c_str()); -} - TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { MakeModuleWithSimpleLoop(/*num_iters=*/0); HloModule* the_module = &module(); @@ -118,15 +74,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { op::Tuple(op::Constant(), op::Constant())); } -TEST_F(WhileLoopSimplifierTest, - LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), - op::Tuple(op::Constant(), op::Constant(), op::Constant())); -} - TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { MakeModuleWithSimpleLoop(/*num_iters=*/1); HloModule* the_module = &module(); @@ -135,15 +82,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { op::Tuple(op::Add(), op::Multiply())); } -TEST_F(WhileLoopSimplifierTest, - LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), - op::Tuple(op::Add(), op::Multiply(), op::Constant())); -} - TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { MakeModuleWithSimpleLoop(/*num_iters=*/2); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); @@ -426,6 +364,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { HloModule BodyHasNonTupleRoot BodyHasNonTupleRoot.passthrough { ROOT param = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1 } BodyHasNonTupleRoot.always_true { param.1 = (s32[], s32[]) parameter(0) @@ -443,38 +382,5 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, - LoopWithNonTupleBodyRootInstructionNotSimplified) { - const string hlo_string = R"( - HloModule SimpleLoop - SimpleLoop.body { - loop_var.1 = (s32[], s32[3]{0}) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.1 = s32[] constant(1) - add = s32[] add(get-tuple-element.1, constant.1) - get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 - multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) - ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply), - custom_call_target="x" - } - SimpleLoop.condition { - loop_var.2 = (s32[], s32[3]{0}) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - constant.2 = s32[] constant(44) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) - } - ENTRY SimpleLoop { - constant.3 = s32[] constant(42) - constant.4 = s32[3]{0} constant({0, 1, 2}) - tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) - ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= - SimpleLoop.condition, body=SimpleLoop.body - } - )"; - - ParseAndVerifyModule(hlo_string.c_str()); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); -} - } // namespace } // namespace xla |