aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2017-07-27 23:36:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-27 23:40:41 -0700
commit560f0d22797ddeb66e8c35fc1c6e24c822ad64e2 (patch)
tree6fcd22896d9f4230bd8f52d3810ae7dcd5ca0a8d /tensorflow/compiler
parent446450369b9f4375cf26c1c1bc3b9a3fd93059f5 (diff)
[XLA] Create CallInliner HLO pass to recursively force-inline kCall operations.
PiperOrigin-RevId: 163436255
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD32
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.cc156
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.h48
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc76
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc36
5 files changed, 348 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 0b85d6febc..cef55f66dc 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -281,6 +281,38 @@ cc_library(
],
)
+cc_library(
+ name = "call_inliner",
+ srcs = ["call_inliner.cc"],
+ hdrs = ["call_inliner.h"],
+ deps = [
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "call_inliner_test",
+ size = "small",
+ srcs = ["call_inliner_test.cc"],
+ deps = [
+ ":call_inliner",
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_test(
name = "flatten_call_graph_test",
size = "small",
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
new file mode 100644
index 0000000000..817b59f762
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -0,0 +1,156 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+
+#include <deque>
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+StatusOr<bool> CallInliner::Run(HloModule* module) {
+ std::deque<HloInstruction*> work_queue;
+
+ // Seed the work queue with call instructions from the main computation.
+ TF_RETURN_IF_ERROR(
+ module->entry_computation()->Accept([&](HloInstruction* hlo) {
+ if (hlo->opcode() == HloOpcode::kCall) {
+ work_queue.push_back(hlo);
+ }
+ return Status::OK();
+ }));
+
+ VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
+
+ bool mutated = false;
+ while (!work_queue.empty()) {
+ mutated = true;
+ HloInstruction* call = work_queue.front();
+ work_queue.pop_front();
+ TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
+ }
+ return mutated;
+}
+
+// Traverses the callee computation, inlining cloned nodes into the caller
+// computation and connecting them to producers/consumers appropriately.
+// When the traversal has completed, the provided call instruction is entirely
+// replaced in the caller's graph, and any calls encountered in the callee
+// computation have been added to the work_queue.
+class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
+ public:
+ SubcomputationInsertionVisitor(HloInstruction* call,
+ std::deque<HloInstruction*>* work_queue)
+ : call_(call), outer_(call->parent()), work_queue_(work_queue) {}
+
+ // Resolves the operands to the HLO instruction in the inlined (caller) graph,
+ // and clones the HLO instruction into that graph with the new operands.
+ // If the instruction is a call, it is added to the work queue.
+ Status DefaultAction(HloInstruction* hlo) override {
+ std::vector<HloInstruction*> new_operands;
+ for (HloInstruction* operand : hlo->operands()) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
+ new_operands.push_back(new_operand);
+ }
+ VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString();
+ auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands);
+ HloInstruction* new_hlo_pointer =
+ outer_->AddInstruction(std::move(new_hlo));
+ TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer));
+
+ // Account for control edges.
+ for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor,
+ Resolve(control_predecessor));
+ TF_RETURN_IF_ERROR(
+ new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
+ }
+
+ if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
+ VLOG(1) << "Adding new call HLO to work queue.";
+ // Call instructions we observe in the subcomputation are added to the
+ // inliner work queue.
+ work_queue_->push_back(new_hlo_pointer);
+ }
+ return Status::OK();
+ }
+
+ // Does not create new nodes for the parameter; rather, notes the mapping from
+ // the subcomputation parameter node to the call operands in the caller
+ // computation.
+ Status HandleParameter(HloInstruction* parameter) override {
+ TF_RETURN_IF_ERROR(NoteMapping(
+ parameter, call_->mutable_operand(parameter->parameter_number())));
+ return Status::OK();
+ }
+
+ // Wires the consumers of the call to instead point at the newly created root,
+ // replacing the call operation in the caller computation.
+ Status FinishVisit(HloInstruction* root) override {
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root));
+ VLOG(1) << "Replacing all uses of " << call_->ToString()
+ << " with new root " << new_root->ToString();
+ return outer_->ReplaceInstruction(call_, new_root);
+ }
+
+ private:
+ // Resolves the callee subcomputation_hlo to the new (inline) HLO in the
+ // caller computation, or returns a NotFound error if that subcomputation HLO
+ // has not been mapped.
+ StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) {
+ auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo);
+ if (it == subcomputation_hlo_to_new_hlo_.end()) {
+ return NotFound(
+ "Could not find mapping from subcomputation HLO %s to a cloned HLO.",
+ subcomputation_hlo->ToString().c_str());
+ }
+ return it->second;
+ }
+
+ // Notes that the given subcomputation_hlo in the callee has been mapped to
+ // the (inline) new_hlo in the caller computation.
+ //
+ // Returns an error status if the subcomputation_hlo is mapped more than
+ // once.
+ Status NoteMapping(HloInstruction* subcomputation_hlo,
+ HloInstruction* new_hlo) {
+ auto result = subcomputation_hlo_to_new_hlo_.insert(
+ std::make_pair(subcomputation_hlo, new_hlo));
+ TF_RET_CHECK(result.second)
+ << "A mapping for the subcomputation HLO is already present.";
+ return Status::OK();
+ }
+
+ HloInstruction* call_;
+ HloComputation* outer_;
+ std::unordered_map<HloInstruction*, HloInstruction*>
+ subcomputation_hlo_to_new_hlo_;
+ std::deque<HloInstruction*>* work_queue_;
+};
+
+Status CallInliner::ReplaceWithInlinedBody(
+ HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
+ TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
+ TF_RET_CHECK(call->called_computations().size() == 1);
+ HloComputation* called = call->called_computations()[0];
+ VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
+ << called->name();
+
+ SubcomputationInsertionVisitor visitor(call, work_queue);
+ return called->Accept(&visitor);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
new file mode 100644
index 0000000000..8647edffa7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+
+#include <deque>
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// For every kCall operation in the main computation, we inline the body of the
+// called function, and proceed recursively.
+class CallInliner : public HloPassInterface {
+ public:
+ ~CallInliner() override = default;
+ tensorflow::StringPiece name() const override { return "CallInliner"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ // Replaces the given call operation -- which must be an operation inside the
+ // entry computation with opcode kCall -- with the called computation's body,
+ // such that the called computation is inline in the entry computation.
+ //
+ // On successful inlining, the inlined computation may have itself contained
+ // calls; if so, they are added to the work_queue.
+ Status ReplaceWithInlinedBody(HloInstruction* call,
+ std::deque<HloInstruction*>* work_queue);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
new file mode 100644
index 0000000000..e6defa78ac
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace {
+
+// Tests for call inlining that are most tractable at the HLO level (vs
+// ComputationBuilder API in call_test.cc).
+using CallInlinerTest = HloTestBase;
+
+TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
+ HloComputation::Builder inner(TestName() + ".inner");
+ HloInstruction* zero = inner.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
+ HloInstruction* one = inner.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ TF_ASSERT_OK(zero->AddControlDependencyTo(one));
+ std::unique_ptr<HloComputation> inner_computation(inner.Build());
+
+ HloComputation::Builder outer(TestName() + ".outer");
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ outer.AddInstruction(
+ HloInstruction::CreateCall(r0f32, {}, inner_computation.get()));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(outer.Build());
+
+ CallInliner call_inliner;
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ ASSERT_TRUE(mutated);
+ EXPECT_THAT(computation->root_instruction(), op::Constant());
+ EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
+ 42);
+ ASSERT_EQ(1, computation->root_instruction()->control_predecessors().size());
+ auto prior = computation->root_instruction()->control_predecessors()[0];
+ EXPECT_THAT(prior, op::Constant());
+ EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index a9919b2fbf..214bc79198 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -73,6 +73,7 @@ class CallOpTest : public ClientLibraryTestBase {
Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
};
+// TODO(b/64094172) Failing on GPU as of 2017-07-26.
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) {
ComputationBuilder builder(client_, TestName());
Computation callee = CreateR0F32IdentityComputation();
@@ -82,6 +83,7 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) {
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
}
+// TODO(b/64094172) Failing on GPU as of 2017-07-26.
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) {
ComputationBuilder builder(client_, TestName());
Computation callee = CreateR1S0F32AdditionComputation();
@@ -92,6 +94,7 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) {
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
}
+// TODO(b/64094172) Failing on GPU as of 2017-07-26.
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) {
ComputationBuilder builder(client_, TestName());
Computation callee = CreateR1S2F32AdditionComputation();
@@ -102,6 +105,39 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) {
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
}
+// TODO(b/64094172) Failing on GPU as of 2017-07-26.
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallTreeTwoDeepBranchFactorThree)) {
+ ComputationBuilder builder(client_, "inner");
+ {
+ auto x = builder.Parameter(0, r0f32_, "x");
+ builder.Add(x, builder.ConstantR0<float>(1.0));
+ }
+ TF_ASSERT_OK_AND_ASSIGN(Computation inner, builder.Build());
+
+ ComputationBuilder builder2(client_, "outer");
+ {
+ auto x = builder2.Parameter(0, r0f32_, "x");
+ x = builder2.Call(inner, {x});
+ x = builder2.Call(inner, {x});
+ x = builder2.Call(inner, {x});
+ }
+ TF_ASSERT_OK_AND_ASSIGN(Computation outer, builder2.Build());
+
+ ComputationBuilder builder3(client_, "outermost");
+ {
+ auto x = builder3.Parameter(0, r0f32_, "x");
+ x = builder3.Call(outer, {x});
+ x = builder3.Call(outer, {x});
+ x = builder3.Call(outer, {x});
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> start,
+ client_->TransferToServer(*Literal::CreateR0<float>(1.0f)));
+ ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
+}
+
+// TODO(b/64094172) Failing on GPU as of 2017-07-26.
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) {
ComputationBuilder builder(client_, TestName());
Computation callee = CreateR0F32TupleComputation();