author     A. Unique TensorFlower <gardener@tensorflow.org>    2017-04-14 09:51:39 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>     2017-04-14 11:07:21 -0700
commit     9c67e16d7f319363804977b106e33faa972ed89f (patch)
tree       4821be95b7e901bb1d8a5b62d2ac13f95c74c389 /tensorflow/compiler/xla/service/flatten_call_graph.cc
parent     4254fa13de1697226e15df26cc101d00d1cb1a03 (diff)
[XLA] Flatten computation call graph
This CL clones computations that are called from more than one call site in a sequential context (call and while nodes) so that the call graph becomes a tree.
Change: 153183115
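For orientation, here is a minimal sketch of invoking the new pass on a module. `FlattenModuleCallGraph` is a hypothetical wrapper, not part of this CL; `FlattenCallGraph` is assumed to be default-constructible (its header is not part of this diff), and the includes mirror the new .cc below, through which the status macros are assumed to be available.

```cpp
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Hypothetical helper, not part of this CL: flattens `module`'s call graph
// in place. FlattenCallGraph::Run returns StatusOr<bool> (see below), so
// errors propagate through TF_ASSIGN_OR_RETURN.
Status FlattenModuleCallGraph(HloModule* module) {
  FlattenCallGraph flatten;  // assumed default-constructible
  TF_ASSIGN_OR_RETURN(bool changed, flatten.Run(module));
  (void)changed;  // the Run() in this CL reports true unconditionally
  return Status::OK();
}

}  // namespace xla
```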
Diffstat (limited to 'tensorflow/compiler/xla/service/flatten_call_graph.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/flatten_call_graph.cc | 113 |
1 file changed, 113 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc
new file mode 100644
index 0000000000..3c41fe870f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
+// Helper to replace the called computation at a while- or call-instruction.
+void ReplaceCalledComputation(HloInstruction* instruction,
+                              HloComputation* computation,
+                              HloComputation* new_computation) {
+  switch (instruction->opcode()) {
+    case HloOpcode::kWhile: {
+      if (computation == instruction->while_condition()) {
+        instruction->set_while_condition(new_computation);
+      } else {
+        CHECK_EQ(computation, instruction->while_body());
+        instruction->set_while_body(new_computation);
+      }
+      break;
+    }
+    case HloOpcode::kCall: {
+      CHECK_EQ(instruction->to_apply(), computation);
+      instruction->set_to_apply(new_computation);
+      break;
+    }
+    default:
+      LOG(FATAL) << "unexpected opcode: "
+                 << HloOpcodeString(instruction->opcode());
+  }
+}
+
+// Flatten a single call graph node. Expects to visit nodes in postorder.
+Status FlattenNode(const CallGraphNode& node) {
+  HloComputation* computation = node.computation();
+  HloModule* module = computation->parent();
+  // Clone callee for all call-sites except the first one.
+  for (int i = 0; i < node.caller_callsites().size(); ++i) {
+    CallSite call_site = node.caller_callsites()[i];
+    // Only consider sequential call contexts.
+    if (call_site.context() == CallContext::kParallel) {
+      continue;
+    }
+    CHECK_EQ(call_site.context(), CallContext::kSequential);
+
+    // Skip first element if this computation is only called from a sequential
+    // context.
+    if (node.context() != CallContext::kBoth && i == 0) {
+      continue;
+    }
+
+    // Clone computation for the remaining sequential context call sites.
+    HloComputation* clone =
+        module->AddEmbeddedComputation(computation->Clone());
+    ReplaceCalledComputation(call_site.instruction(), computation, clone);
+    // Clone the sub-tree of all computations called from this node.
+    std::vector<HloComputation*> worklist;
+    worklist.push_back(clone);
+    while (!worklist.empty()) {
+      auto current = worklist.back();
+      worklist.pop_back();
+      for (auto& instruction : current->instructions()) {
+        if (GetInstructionCallContext(instruction.get()) !=
+            CallContext::kSequential) {
+          continue;
+        }
+        for (auto callee : instruction->called_computations()) {
+          HloComputation* callee_clone =
+              module->AddEmbeddedComputation(callee->Clone());
+          ReplaceCalledComputation(instruction.get(), callee, callee_clone);
+          worklist.push_back(callee_clone);
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+StatusOr<bool> FlattenCallGraph::Run(HloModule* module) {
+  XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<CallGraph> call_graph,
+                      CallGraph::Build(module));
+  TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode));
+
+  XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
+  return true;
+}
+
+}  // namespace xla
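To make the transformation concrete, here is a self-contained toy model of the same technique, with illustrative names only (no XLA APIs): every node reached through more than one caller keeps its first caller and is cloned, together with its entire callee sub-tree, for each additional one, until the call graph is a tree.

```cpp
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

// A call-graph node: a computation and the computations it calls.
struct Node {
  std::string name;
  std::vector<Node*> callees;
};

class Graph {
 public:
  Node* AddNode(std::string name) {
    nodes_.push_back(std::make_unique<Node>(Node{std::move(name), {}}));
    return nodes_.back().get();
  }

  // Keep the first caller of every node; give each additional caller its
  // own clone of the callee sub-tree, as FlattenNode does above.
  void Flatten() {
    std::set<Node*> claimed;
    for (size_t i = 0; i < nodes_.size(); ++i) {  // nodes_ grows during cloning
      for (Node*& callee : nodes_[i]->callees) {
        if (!claimed.insert(callee).second) {
          callee = CloneSubtree(callee);  // a later caller takes a fresh copy
          claimed.insert(callee);
        }
      }
    }
  }

  void Print() const {
    for (const auto& node : nodes_) {
      std::cout << node->name << " ->";
      for (const Node* callee : node->callees) std::cout << ' ' << callee->name;
      std::cout << '\n';
    }
  }

 private:
  // Recursive analogue of the worklist loop in FlattenNode: clone a node
  // and everything reachable from it.
  Node* CloneSubtree(Node* node) {
    Node* clone = AddNode(node->name + ".clone");
    for (Node* callee : node->callees) {
      clone->callees.push_back(CloneSubtree(callee));
    }
    return clone;
  }

  std::vector<std::unique_ptr<Node>> nodes_;
};

int main() {
  Graph g;
  Node* body = g.AddNode("body");  // a body shared by two while nodes
  body->callees.push_back(g.AddNode("leaf"));
  g.AddNode("while_a")->callees.push_back(body);
  g.AddNode("while_b")->callees.push_back(body);
  g.Flatten();
  g.Print();  // while_b now calls body.clone -> leaf.clone: a tree
  return 0;
}
```

After `Flatten()`, `while_a` still calls the original `body`, while `while_b` calls `body.clone`, whose sub-tree (`leaf.clone`) was copied along with it; this mirrors how the CL skips the first sequential call site and clones the callee sub-tree for each remaining one.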