aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-02-28 13:43:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 13:47:48 -0800
commit0f3105c39b079d8e7741e48e3b098c47c81a453a (patch)
treef4cd1189718b083bd78c1120e451a2f0ad8dbfa2
parentc661f2c3de75e3ad58bce52b39b8cc2e7ee07c0e (diff)
[XLA] Add a HLO simplifier pass to fold Conditional(constant_predicate,
true_computation, false_computation) to Call(predicated_computation) and finally inlined computation. PiperOrigin-RevId: 187376657
-rw-r--r--tensorflow/compiler/xla/service/BUILD35
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.cc106
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.h38
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc153
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc2
8 files changed, 338 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e6a6e54927..e4ae812532 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1214,6 +1214,41 @@ tf_cc_test(
)
cc_library(
+ name = "conditional_simplifier",
+ srcs = ["conditional_simplifier.cc"],
+ hdrs = ["conditional_simplifier.h"],
+ deps = [
+ ":call_inliner",
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "conditional_simplifier_test",
+ srcs = ["conditional_simplifier_test.cc"],
+ deps = [
+ ":conditional_simplifier",
+ ":hlo",
+ ":hlo_matchers",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "while_loop_simplifier",
srcs = ["while_loop_simplifier.cc"],
hdrs = ["while_loop_simplifier.h"],
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
new file mode 100644
index 0000000000..f35de08085
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+// Tries to replace a conditional with a call operation of the corresponding
+// computation. If the given conditional has a constant predicate, tries to
+// replace it with a call to its true/false computation as appropirate and then
+// inline that computation.
+//
+// Returns true if it made a change to the graph.
+static StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
+ CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
+ // Do not remove conditionals that contain side-effecting instructions or
+ // have control predecessors/successors in either true/false computation.
+ if (!conditional->parent()->IsRemovable(conditional) ||
+ conditional->HasSideEffect()) {
+ VLOG(2) << "Not attempting to remove conditional as it is not removable or "
+ "has side effect: "
+ << conditional->ToShortString();
+ return false;
+ }
+
+ if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
+ VLOG(2) << "Not attempting to remove conditional as its predicate is not a "
+ "compile-time constant: "
+ << conditional->ToShortString();
+ return false;
+ }
+
+ auto computation = conditional->parent();
+ HloInstruction* call_op;
+ if (conditional->operand(0)->literal().Get<bool>({})) {
+ call_op = computation->AddInstruction(HloInstruction::CreateCall(
+ conditional->shape(), {conditional->mutable_operand(1)},
+ conditional->true_computation()));
+ } else {
+ call_op = computation->AddInstruction(HloInstruction::CreateCall(
+ conditional->shape(), {conditional->mutable_operand(2)},
+ conditional->false_computation()));
+ }
+
+ TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
+ TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
+
+ return true;
+}
+
+StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
+ XLA_VLOG_LINES(
+ 3, "ConditionalSimplifier::Run(), before:\n" + module->ToString());
+ bool changed = false;
+
+ // Gather all the conditional ops in our module. We do this ahead of time so
+ // we don't have to worry about mutating the lists of computations or
+ // instructions as we iterate.
+ std::vector<HloInstruction*> conditional_ops;
+ for (auto* comp : module->computations()) {
+ for (auto* instr : comp->instructions()) {
+ if (instr->opcode() == HloOpcode::kConditional) {
+ conditional_ops.push_back(instr);
+ }
+ }
+ }
+
+ for (HloInstruction* conditional_op : conditional_ops) {
+ TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
+ changed |= result;
+ }
+
+ XLA_VLOG_LINES(3,
+ "ConditionalSimplifier::Run(), after:\n" + module->ToString());
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
new file mode 100644
index 0000000000..063261e26d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// HLO pass that removes kConditional with a constant predicate, replacing them
+// with their true or false computation as appropriate.
+class ConditionalSimplifier : public HloPassInterface {
+ public:
+ tensorflow::StringPiece name() const override {
+ return "simplify-conditional";
+ }
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
new file mode 100644
index 0000000000..868348547d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -0,0 +1,153 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+class ConditionalSimplifierTest : public HloVerifiedTestBase {
+ public:
+ // Makes a computation that contains a conditional with constant predicate.
+ HloComputation* MakeConditional(HloModule* module);
+};
+
+HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
+ HloComputation::Builder builder(TestName());
+
+ // true_computation returns param+1.
+ HloComputation* true_computation;
+ {
+ HloComputation::Builder true_computation_builder(TestName() +
+ ".true_computation");
+ auto param =
+ true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(S32, {}), "param"));
+ auto one = true_computation_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+
+ true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
+
+ true_computation =
+ module->AddEmbeddedComputation(true_computation_builder.Build());
+ }
+
+ // false_computation returns param+42.
+ HloComputation* false_computation;
+ {
+ HloComputation::Builder false_computation_builder(TestName() +
+ ".false_computation");
+ auto param = false_computation_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
+ "param"));
+ auto forty_two = false_computation_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
+ false_computation =
+ module->AddEmbeddedComputation(false_computation_builder.Build());
+ }
+
+ auto false_instrn = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(S32, {}), "false_param"));
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+
+ builder.AddInstruction(HloInstruction::CreateConditional(
+ ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
+ false_param, false_computation));
+
+ return module->AddEntryComputation(builder.Build());
+}
+
+TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
+ HloComputation* computation = MakeConditional(&module());
+ ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie());
+ EXPECT_THAT(computation->root_instruction(),
+ op::Add(op::Parameter(), op::Constant()));
+}
+
+TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
+ HloComputation* computation = MakeConditional(&module());
+
+ auto* true_op = computation->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ TF_ASSERT_OK(
+ true_op->AddControlDependencyTo(computation->root_instruction()));
+
+ EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
+}
+
+TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
+ HloComputation* computation = MakeConditional(&module());
+ auto* conditional = computation->root_instruction();
+ ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
+
+ auto* true_computation = conditional->true_computation();
+ auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
+ true_computation->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ /*channel_id=*/0));
+ true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
+ EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
+}
+
+TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
+ HloComputation* computation = MakeConditional(&module());
+ auto* conditional = computation->root_instruction();
+ ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
+
+ auto* true_computation = conditional->true_computation();
+ auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv(
+ ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0));
+ true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
+ EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
+}
+
+TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
+ HloComputation* computation = MakeConditional(&module());
+ auto* conditional = computation->root_instruction();
+ ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
+ auto* false_computation = conditional->false_computation();
+ false_computation->AddInstruction(
+ HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
+ EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 4170e31527..38a54fcb64 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -105,6 +105,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
+ "//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 387806e24a..0d15be5a23 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
@@ -275,6 +276,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
pass.AddPass<HloDCE>();
pass.AddPass<ReshapeMover>();
pass.AddPass<HloConstantFolding>();
+ pass.AddPass<ConditionalSimplifier>();
}
pipeline.AddPass<TransposeFolding>(
[](const HloInstruction& dot,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 9da4fb97fa..334efff1e6 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -510,6 +510,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
+ "//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 28ebd034ee..9e37acdf31 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -176,6 +177,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
pass.AddPass<HloDCE>();
pass.AddPass<ReshapeMover>();
pass.AddPass<HloConstantFolding>();
+ pass.AddPass<ConditionalSimplifier>();
}
pipeline.AddPass<TransposeFolding>(