diff options
author | 2017-07-27 23:36:58 -0700 | |
---|---|---|
committer | 2017-07-27 23:40:41 -0700 | |
commit | 560f0d22797ddeb66e8c35fc1c6e24c822ad64e2 (patch) | |
tree | 6fcd22896d9f4230bd8f52d3810ae7dcd5ca0a8d /tensorflow/compiler | |
parent | 446450369b9f4375cf26c1c1bc3b9a3fd93059f5 (diff) |
[XLA] Create CallInliner HLO pass to recursively force-inline kCall operations.
PiperOrigin-RevId: 163436255
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_inliner.cc | 156 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_inliner.h | 48 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_inliner_test.cc | 76 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/call_test.cc | 36 |
5 files changed, 348 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0b85d6febc..cef55f66dc 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -281,6 +281,38 @@ cc_library( ], ) +cc_library( + name = "call_inliner", + srcs = ["call_inliner.cc"], + hdrs = ["call_inliner.h"], + deps = [ + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "call_inliner_test", + size = "small", + srcs = ["call_inliner_test.cc"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_test( name = "flatten_call_graph_test", size = "small", diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc new file mode 100644 index 0000000000..817b59f762 --- /dev/null +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/call_inliner.h" + +#include <deque> + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr<bool> CallInliner::Run(HloModule* module) { + std::deque<HloInstruction*> work_queue; + + // Seed the work queue with call instructions from the main computation. + TF_RETURN_IF_ERROR( + module->entry_computation()->Accept([&](HloInstruction* hlo) { + if (hlo->opcode() == HloOpcode::kCall) { + work_queue.push_back(hlo); + } + return Status::OK(); + })); + + VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries."; + + bool mutated = false; + while (!work_queue.empty()) { + mutated = true; + HloInstruction* call = work_queue.front(); + work_queue.pop_front(); + TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue)); + } + return mutated; +} + +// Traverses the callee computation, inlining cloned nodes into the caller +// computation and connecting them to producers/consumers appropriately. +// When the traversal has completed, the provided call instruction is entirely +// replaced in the caller's graph, and any calls encountered in the callee +// computation have been added to the work_queue. +class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { + public: + SubcomputationInsertionVisitor(HloInstruction* call, + std::deque<HloInstruction*>* work_queue) + : call_(call), outer_(call->parent()), work_queue_(work_queue) {} + + // Resolves the operands to the HLO instruction in the inlined (caller) graph, + // and clones the HLO instruction into that graph with the new operands. + // If the instruction is a call, it is added to the work queue. 
+ Status DefaultAction(HloInstruction* hlo) override { + std::vector<HloInstruction*> new_operands; + for (HloInstruction* operand : hlo->operands()) { + TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand)); + new_operands.push_back(new_operand); + } + VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString(); + auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands); + HloInstruction* new_hlo_pointer = + outer_->AddInstruction(std::move(new_hlo)); + TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer)); + + // Account for control edges. + for (HloInstruction* control_predecessor : hlo->control_predecessors()) { + TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor, + Resolve(control_predecessor)); + TF_RETURN_IF_ERROR( + new_control_predecessor->AddControlDependencyTo(new_hlo_pointer)); + } + + if (new_hlo_pointer->opcode() == HloOpcode::kCall) { + VLOG(1) << "Adding new call HLO to work queue."; + // Call instructions we observe in the subcomputation are added to the + // inliner work queue. + work_queue_->push_back(new_hlo_pointer); + } + return Status::OK(); + } + + // Does not create new nodes for the parameter; rather, notes the mapping from + // the subcomputation parameter node to the call operands in the caller + // computation. + Status HandleParameter(HloInstruction* parameter) override { + TF_RETURN_IF_ERROR(NoteMapping( + parameter, call_->mutable_operand(parameter->parameter_number()))); + return Status::OK(); + } + + // Wires the consumers of the call to instead point at the newly created root, + // replacing the call operation in the caller computation. 
+ Status FinishVisit(HloInstruction* root) override { + TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root)); + VLOG(1) << "Replacing all uses of " << call_->ToString() + << " with new root " << new_root->ToString(); + return outer_->ReplaceInstruction(call_, new_root); + } + + private: + // Resolves the callee subcomputation_hlo to the new (inline) HLO in the + // caller computation, or returns a NotFound error if that subcomputation HLO + // has not been mapped. + StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) { + auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo); + if (it == subcomputation_hlo_to_new_hlo_.end()) { + return NotFound( + "Could not find mapping from subcomputation HLO %s to a cloned HLO.", + subcomputation_hlo->ToString().c_str()); + } + return it->second; + } + + // Notes that the given subcomputation_hlo in the callee has been mapped to + // the (inline) new_hlo in the caller computation. + // + // Returns an error status if the subcomputation_hlo is mapped more than + // once. 
+ Status NoteMapping(HloInstruction* subcomputation_hlo, + HloInstruction* new_hlo) { + auto result = subcomputation_hlo_to_new_hlo_.insert( + std::make_pair(subcomputation_hlo, new_hlo)); + TF_RET_CHECK(result.second) + << "A mapping for the subcomputation HLO is already present."; + return Status::OK(); + } + + HloInstruction* call_; + HloComputation* outer_; + std::unordered_map<HloInstruction*, HloInstruction*> + subcomputation_hlo_to_new_hlo_; + std::deque<HloInstruction*>* work_queue_; +}; + +Status CallInliner::ReplaceWithInlinedBody( + HloInstruction* call, std::deque<HloInstruction*>* work_queue) { + TF_RET_CHECK(call->opcode() == HloOpcode::kCall); + TF_RET_CHECK(call->called_computations().size() == 1); + HloComputation* called = call->called_computations()[0]; + VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of " + << called->name(); + + SubcomputationInsertionVisitor visitor(call, work_queue); + return called->Accept(&visitor); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h new file mode 100644 index 0000000000..8647edffa7 --- /dev/null +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ + +#include <deque> + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// For every kCall operation in the main computation, we inline the body of the +// called function, and proceed recursively. +class CallInliner : public HloPassInterface { + public: + ~CallInliner() override = default; + tensorflow::StringPiece name() const override { return "CallInliner"; } + + StatusOr<bool> Run(HloModule* module) override; + + private: + // Replaces the given call operation -- which must be an operation inside the + // entry computation with opcode kCall -- with the called computation's body, + // such that the called computation is inline in the entry computation. + // + // On successful inlining, the inlined computation may have itself contained + // calls; if so, they are added to the work_queue. + Status ReplaceWithInlinedBody(HloInstruction* call, + std::deque<HloInstruction*>* work_queue); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc new file mode 100644 index 0000000000..e6defa78ac --- /dev/null +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/call_inliner.h" + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +// Tests for call inlining that are most tractable at the HLO level (vs +// ComputationBuilder API in call_test.cc). 
+using CallInlinerTest = HloTestBase; + +TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { + HloComputation::Builder inner(TestName() + ".inner"); + HloInstruction* zero = inner.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f))); + HloInstruction* one = inner.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + TF_ASSERT_OK(zero->AddControlDependencyTo(one)); + std::unique_ptr<HloComputation> inner_computation(inner.Build()); + + HloComputation::Builder outer(TestName() + ".outer"); + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + outer.AddInstruction( + HloInstruction::CreateCall(r0f32, {}, inner_computation.get())); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(outer.Build()); + + CallInliner call_inliner; + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_TRUE(mutated); + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), + 42); + ASSERT_EQ(1, computation->root_instruction()->control_predecessors().size()); + auto prior = computation->root_instruction()->control_predecessors()[0]; + EXPECT_THAT(prior, op::Constant()); + EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index a9919b2fbf..214bc79198 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -73,6 +73,7 @@ class CallOpTest : public ClientLibraryTestBase { Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); }; +// TODO(b/64094172) Failing on GPU as of 2017-07-26. 
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32IdentityComputation(); @@ -82,6 +83,7 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f)); } +// TODO(b/64094172) Failing on GPU as of 2017-07-26. XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S0F32AdditionComputation(); @@ -92,6 +94,7 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f)); } +// TODO(b/64094172) Failing on GPU as of 2017-07-26. XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S2F32AdditionComputation(); @@ -102,6 +105,39 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); } +// TODO(b/64094172) Failing on GPU as of 2017-07-26. 
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallTreeTwoDeepBranchFactorThree)) { + ComputationBuilder builder(client_, "inner"); + { + auto x = builder.Parameter(0, r0f32_, "x"); + builder.Add(x, builder.ConstantR0<float>(1.0)); + } + TF_ASSERT_OK_AND_ASSIGN(Computation inner, builder.Build()); + + ComputationBuilder builder2(client_, "outer"); + { + auto x = builder2.Parameter(0, r0f32_, "x"); + x = builder2.Call(inner, {x}); + x = builder2.Call(inner, {x}); + x = builder2.Call(inner, {x}); + } + TF_ASSERT_OK_AND_ASSIGN(Computation outer, builder2.Build()); + + ComputationBuilder builder3(client_, "outermost"); + { + auto x = builder3.Parameter(0, r0f32_, "x"); + x = builder3.Call(outer, {x}); + x = builder3.Call(outer, {x}); + x = builder3.Call(outer, {x}); + } + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<GlobalData> start, + client_->TransferToServer(*Literal::CreateR0<float>(1.0f))); + ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); +} + +// TODO(b/64094172) Failing on GPU as of 2017-07-26. XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32TupleComputation(); |