aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/flatten_call_graph.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-14 09:51:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-14 11:07:21 -0700
commit9c67e16d7f319363804977b106e33faa972ed89f (patch)
tree4821be95b7e901bb1d8a5b62d2ac13f95c74c389 /tensorflow/compiler/xla/service/flatten_call_graph.cc
parent4254fa13de1697226e15df26cc101d00d1cb1a03 (diff)
[XLA] Flatten computation call graph
This CL clones computations that are called from more than one call site in a sequential context (call and while nodes) so that the call graph becomes a tree. Change: 153183115
Diffstat (limited to 'tensorflow/compiler/xla/service/flatten_call_graph.cc')
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph.cc113
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc
new file mode 100644
index 0000000000..3c41fe870f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
+// Replaces `computation` with `new_computation` as a callee of
+// `instruction`. Only kWhile and kCall instructions are supported; any
+// other opcode is a logic error and aborts the process.
+void ReplaceCalledComputation(HloInstruction* instruction,
+                              HloComputation* computation,
+                              HloComputation* new_computation) {
+  const HloOpcode opcode = instruction->opcode();
+  if (opcode == HloOpcode::kWhile) {
+    // A while instruction references two computations: the loop condition
+    // and the loop body. Replace whichever one matches `computation`.
+    if (instruction->while_condition() == computation) {
+      instruction->set_while_condition(new_computation);
+    } else {
+      CHECK_EQ(computation, instruction->while_body());
+      instruction->set_while_body(new_computation);
+    }
+  } else if (opcode == HloOpcode::kCall) {
+    // A call instruction references exactly one computation.
+    CHECK_EQ(instruction->to_apply(), computation);
+    instruction->set_to_apply(new_computation);
+  } else {
+    LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(opcode);
+  }
+}
+
+// Flattens a single call graph node: for every sequential call site beyond
+// the first, clones the node's computation — and, transitively, every
+// computation it calls in a sequential context — so that the computation
+// ends up with at most one sequential caller. Expects to visit nodes in
+// postorder.
+Status FlattenNode(const CallGraphNode& node) {
+  HloComputation* computation = node.computation();
+  HloModule* module = computation->parent();
+  // Hoist the call-site list out of the loop; binding a const reference is
+  // safe whether the accessor returns by reference or by value (lifetime
+  // extension), and avoids copying a CallSite on every iteration.
+  const auto& caller_callsites = node.caller_callsites();
+  // Clone callee for all call-sites except the first one. Use size_t to
+  // avoid a signed/unsigned comparison against size().
+  for (size_t i = 0; i < caller_callsites.size(); ++i) {
+    const CallSite& call_site = caller_callsites[i];
+    // Only consider sequential call contexts; parallel callers (e.g. map)
+    // are allowed to share the callee.
+    if (call_site.context() == CallContext::kParallel) {
+      continue;
+    }
+    CHECK_EQ(call_site.context(), CallContext::kSequential);
+
+    // Skip first element if this computation is only called from a sequential
+    // context. If the node is called from both contexts (kBoth), every
+    // sequential caller gets a clone.
+    if (node.context() != CallContext::kBoth && i == 0) {
+      continue;
+    }
+
+    // Clone computation for the remaining sequential context call sites.
+    HloComputation* clone =
+        module->AddEmbeddedComputation(computation->Clone());
+    ReplaceCalledComputation(call_site.instruction(), computation, clone);
+    // Clone the sub-tree of all computations called from this node, so the
+    // clone shares no callees with the original and the call graph stays a
+    // tree. Iterative DFS over the freshly created clones.
+    std::vector<HloComputation*> worklist;
+    worklist.push_back(clone);
+    while (!worklist.empty()) {
+      HloComputation* current = worklist.back();
+      worklist.pop_back();
+      for (auto& instruction : current->instructions()) {
+        if (GetInstructionCallContext(instruction.get()) !=
+            CallContext::kSequential) {
+          continue;
+        }
+        for (HloComputation* callee : instruction->called_computations()) {
+          HloComputation* callee_clone =
+              module->AddEmbeddedComputation(callee->Clone());
+          ReplaceCalledComputation(instruction.get(), callee, callee_clone);
+          worklist.push_back(callee_clone);
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+} // namespace
+
+StatusOr<bool> FlattenCallGraph::Run(HloModule* module) {
+  XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
+
+  // Build the call graph for the module and flatten every node; VisitNodes
+  // drives the per-node FlattenNode helper defined above.
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<CallGraph> graph,
+                      CallGraph::Build(module));
+  TF_RETURN_IF_ERROR(graph->VisitNodes(FlattenNode));
+
+  XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
+  // Conservatively report the module as changed.
+  return true;
+}
+
+} // namespace xla