aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 10:33:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 10:38:50 -0800
commit87dab2d8289750c9d34f26d7d5fb18475dff985b (patch)
tree17d1d0de205110553f015946fdd45a488ad325d7 /tensorflow/compiler/xla
parent58d5fa05a67b65979708f541336c2c11bfed978e (diff)
Automated g4 rollback of changelist 188397087
PiperOrigin-RevId: 188503184
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc76
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc96
2 files changed, 2 insertions, 170 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 1a93a880dd..c9d77c9376 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -606,75 +605,6 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
return false;
}
-static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
- auto while_init = while_op->operand(0);
- if (while_init->opcode() != HloOpcode::kTuple) {
- return false;
- }
-
- auto while_body = while_op->while_body();
- auto while_body_root = while_body->root_instruction();
- if (while_body_root->opcode() != HloOpcode::kTuple) {
- return false;
- }
-
- auto while_body_param = while_body->parameter_instruction(0);
- const HloInstruction::InstructionVector& root_operands =
- while_body_root->operands();
-
- // Find the loop invariant tuple elements with constant init value and
- // build a map from the tuple element index to the constant value.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
- for (int i = 0; i < root_operands.size(); i++) {
- HloInstruction* instr = root_operands[i];
- if (instr->opcode() == HloOpcode::kGetTupleElement &&
- instr->tuple_index() == i && instr->operand(0) == while_body_param) {
- auto tuple_element = while_init->operand(i);
- if (tuple_element->IsConstant()) {
- VLOG(3) << "Found loop invariant tuple element " << i << " "
- << tuple_element->ToString();
- index_to_constant[i] = tuple_element;
- }
- }
- }
-
- if (index_to_constant.empty()) {
- return false;
- }
-
- // Replace the use of each constant tuple element in the loop_condition and
- // loop_body with the corresponding constant value.
- auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
- HloInstruction* param = computation->parameter_instruction(0);
- bool changed = false;
- for (auto instr : param->users()) {
- // Since only a while-loop with a tuple result reaches here, we can safely
- // assume that `param` is a tuple and the first operand of the
- // GetTupleElement instruction is a use of `param`.
- if (instr->opcode() == HloOpcode::kGetTupleElement) {
- VLOG(3) << "tuple index " << instr->tuple_index() << " "
- << instr->ToString();
- auto iter = index_to_constant.find(instr->tuple_index());
- if (iter != index_to_constant.end()) {
- const HloInstruction* hlo_constant = (*iter).second;
- VLOG(3) << "Replace use of " << instr->ToString() << " with "
- << hlo_constant->ToString();
- TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
- computation->AddInstruction(hlo_constant->Clone())));
- changed = true;
- }
- }
- }
- return changed;
- };
-
- TF_ASSIGN_OR_RETURN(bool changed_cond,
- propagate_constant(while_op->while_condition()));
- TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
-
- return changed_cond || changed_body;
-}
-
StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(3,
"WhileLoopSimplifier::Run(), before:\n" + module->ToString());
@@ -705,11 +635,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
continue;
}
- StatusOr<bool> result = TryPropagateConstant(while_op);
- TF_RETURN_IF_ERROR(result.status());
- changed |= result.ValueOrDie();
-
- result = TryRemoveWhileLoop(while_op);
+ StatusOr<bool> result = TryRemoveWhileLoop(while_op);
TF_RETURN_IF_ERROR(result.status());
if (result.ValueOrDie()) {
changed = true;
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 396f942dc0..cbea3e3cf2 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -30,11 +30,6 @@ class WhileLoopSimplifierTest : public HloVerifiedTestBase {
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
-
- // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
- // the loop-condition through an element of a tuple which is the
- // loop-condition parameter.
- void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
};
void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
@@ -71,45 +66,6 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
ParseAndVerifyModule(hlo_string.c_str());
}
-void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
- int num_iters) {
- string hlo_string_template = R"(
- HloModule SimpleLoopWithIndirectLoopBound
- SimpleLoopWithIndirectLoopBound.body {
- loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
- constant.1 = s32[] constant(1)
- add = s32[] add(get-tuple-element.1, constant.1)
- get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
- multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
- limit = s32[] get-tuple-element(loop_var.1), index=2
- ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
- }
- SimpleLoopWithIndirectLoopBound.condition {
- loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
- get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
- get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
- ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
- }
- ENTRY SimpleLoopWithIndirectLoopBound {
- constant.3 = s32[] constant(42)
- constant.4 = s32[3]{0} constant({0, 1, 2})
- constant.2 = s32[] constant({{LOOP_BOUND}})
- tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
- constant.2)
- ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
- condition=SimpleLoopWithIndirectLoopBound.condition,
- body=SimpleLoopWithIndirectLoopBound.body
- }
- )";
-
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
- ParseAndVerifyModule(hlo_string.c_str());
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/0);
HloModule* the_module = &module();
@@ -118,15 +74,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
op::Tuple(op::Constant(), op::Constant()));
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithZeroIterationTupleElementLoopBoundSimplified) {
- MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
- HloModule* the_module = &module();
- ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
- EXPECT_THAT(the_module->entry_computation()->root_instruction(),
- op::Tuple(op::Constant(), op::Constant(), op::Constant()));
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/1);
HloModule* the_module = &module();
@@ -135,15 +82,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
op::Tuple(op::Add(), op::Multiply()));
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithOneIterationTupleELementLoopBoundSimplified) {
- MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
- HloModule* the_module = &module();
- ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
- EXPECT_THAT(the_module->entry_computation()->root_instruction(),
- op::Tuple(op::Add(), op::Multiply(), op::Constant()));
-}
-
TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/2);
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
@@ -426,6 +364,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
HloModule BodyHasNonTupleRoot
BodyHasNonTupleRoot.passthrough {
ROOT param = (s32[], s32[]) parameter(0)
+ get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1
}
BodyHasNonTupleRoot.always_true {
param.1 = (s32[], s32[]) parameter(0)
@@ -443,38 +382,5 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
-TEST_F(WhileLoopSimplifierTest,
- LoopWithNonTupleBodyRootInstructionNotSimplified) {
- const string hlo_string = R"(
- HloModule SimpleLoop
- SimpleLoop.body {
- loop_var.1 = (s32[], s32[3]{0}) parameter(0)
- get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
- constant.1 = s32[] constant(1)
- add = s32[] add(get-tuple-element.1, constant.1)
- get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
- multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
- ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
- custom_call_target="x"
- }
- SimpleLoop.condition {
- loop_var.2 = (s32[], s32[3]{0}) parameter(0)
- get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
- constant.2 = s32[] constant(44)
- ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
- }
- ENTRY SimpleLoop {
- constant.3 = s32[] constant(42)
- constant.4 = s32[3]{0} constant({0, 1, 2})
- tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
- ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
- SimpleLoop.condition, body=SimpleLoop.body
- }
- )";
-
- ParseAndVerifyModule(hlo_string.c_str());
- EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
-}
-
} // namespace
} // namespace xla