diff options
author | Mark Heffernan <meheff@google.com> | 2017-03-17 11:27:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-17 12:45:35 -0700 |
commit | d687f9bb40d978912ef8ddd531fa39e032de4c39 (patch) | |
tree | e14e1d16f52439e07752501694b388317c54aafd /tensorflow/compiler/xla/service/call_graph.cc | |
parent | c76bc807c3089069d520edd6ab7f3f07a0c76f53 (diff) |
[XLA] Add mapping from HloInstruction to CallGraphNode.
Make mapping from calling instruction to CallSite one-to-one by extending CallSite to handle more than one called computation. This enables instructions like kWhile which call two computations to be represented as a single CallSite. Also add a mapping from instruction to CallSite in CallGraphNode to enable fast call site lookup.
Also, include a few other opportunistic improvements:
* Change the CallGraph::Build factory to return a std::unique_ptr. This enables,
for example, more convenient use of CallGraph as a data member of a class.
* Change a few uses of unordered_set/map to FlatSet/Map.
Change: 150469958
Diffstat (limited to 'tensorflow/compiler/xla/service/call_graph.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/call_graph.cc | 155 |
1 file changed, 89 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index fdd963a84a..f75aa9bd11 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,10 +17,13 @@ limitations under the License. #include <queue> +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" @@ -49,64 +52,79 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { } string CallSite::ToString() const { - return StrCat(instruction->name(), " calls ", called_computation->name(), - ", ", CallContextToString(context)); + return StrCat(instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + tensorflow::str_util::Join( + called_computations(), ", ", + [](string* out, const HloComputation* computation) { + out->append(computation->name()); + })); } CallGraphNode::CallGraphNode(HloComputation* computation) : computation_(computation) {} -void CallGraphNode::AddCallSite(const CallSite& callsite) { - callsites_.push_back(callsite); - HloComputation* callee = callsite.called_computation; - if (callee_set_.count(callee) == 0) { - callees_.push_back(callee); - callee_set_.insert(callee); +const CallSite* CallGraphNode::GetCallSite( + const HloInstruction* instruction) const { + auto it = callsite_instructions_.find(instruction); + if (it == callsite_instructions_.end()) { + return nullptr; } + return &callsites_[it->second]; } void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { 
caller_callsites_.push_back(caller_callsite); - HloComputation* caller = caller_callsite.instruction->parent(); - if (caller_set_.count(caller) == 0) { + HloComputation* caller = caller_callsite.instruction()->parent(); + if (!ContainsKey(caller_set_, caller)) { callers_.push_back(caller); caller_set_.insert(caller); } } -void CallGraphNode::AddCallSitesInInstruction(HloInstruction* instruction) { +namespace { + +CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: - AddCallSite( - {instruction, instruction->to_apply(), CallContext::kSequential}); - break; + case HloOpcode::kWhile: + return CallContext::kSequential; case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - AddCallSite( - {instruction, instruction->to_apply(), CallContext::kParallel}); - break; case HloOpcode::kSelectAndScatter: - AddCallSite({instruction, instruction->select(), CallContext::kParallel}); - AddCallSite( - {instruction, instruction->scatter(), CallContext::kParallel}); - break; - case HloOpcode::kWhile: - AddCallSite({instruction, instruction->while_condition(), - CallContext::kSequential}); - AddCallSite( - {instruction, instruction->while_body(), CallContext::kSequential}); - break; case HloOpcode::kFusion: - for (const auto& fused_instruction : instruction->fused_instructions()) { - AddCallSitesInInstruction(fused_instruction.get()); - } - break; + return CallContext::kParallel; default: - break; + return CallContext::kNone; } } +} // namespace + +Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { + TF_RET_CHECK(instruction->parent() == computation()); + CallContext context = GetInstructionCallContext(instruction); + if (instruction->called_computations().empty()) { + TF_RET_CHECK(context == CallContext::kNone); + } else { + TF_RET_CHECK(context == CallContext::kSequential || + context == CallContext::kParallel); + 
callsite_instructions_.insert({instruction, callsites_.size()}); + callsites_.push_back( + CallSite(instruction, instruction->called_computations(), context)); + // Update callee computations to include any new computations called by this + // instruction. + for (auto* callee : callsites_.back().called_computations()) { + if (!ContainsKey(callee_set_, callee)) { + callees_.push_back(callee); + callee_set_.insert(callee); + } + } + } + return Status::OK(); +} + CallGraph::CallGraph(const HloModule* module) : module_(module) {} StatusOr<const CallGraphNode*> CallGraph::GetNode( @@ -161,25 +179,26 @@ Status CallGraph::SetCallContexts() { worklist.pop(); for (const CallSite& callsite : node->callsites()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, - GetNode(callsite.called_computation)); - - // Update context of callee computation based on the callsite and its - // current context. - CallContext context_to_add; - if (callsite.context == CallContext::kParallel) { - context_to_add = CallContext::kParallel; - } else { - TF_RET_CHECK(callsite.context == CallContext::kSequential); - context_to_add = node->context(); - } - CallContext new_context = - UnionContexts(context_to_add, callee_node->context()); - - if (new_context != callee_node->context()) { - // Context of computation has been changed so add node to worklist. - callee_node->set_context(new_context); - worklist.push(callee_node); + for (const HloComputation* callee : callsite.called_computations()) { + TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee)); + + // Update context of callee computation based on the callsite and its + // current context. 
+ CallContext context_to_add; + if (callsite.context() == CallContext::kParallel) { + context_to_add = CallContext::kParallel; + } else { + TF_RET_CHECK(callsite.context() == CallContext::kSequential); + context_to_add = node->context(); + } + CallContext new_context = + UnionContexts(context_to_add, callee_node->context()); + + if (new_context != callee_node->context()) { + // Context of computation has been changed so add node to worklist. + callee_node->set_context(new_context); + worklist.push(callee_node); + } } } } @@ -194,23 +213,25 @@ Status CallGraph::SetCallContexts() { } /* static */ -StatusOr<CallGraph> CallGraph::Build(const HloModule* module) { - CallGraph call_graph(module); +StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) { + // Constructor for CallGraph is private so MakeUnique can't be used. + auto call_graph = WrapUnique<CallGraph>(new CallGraph(module)); // Construct nodes of the call graph and populate the callsites. for (const std::unique_ptr<HloComputation>& computation : module->computations()) { - auto it_added = call_graph.node_indices_.insert( - {computation.get(), call_graph.nodes_.size()}); + auto it_added = call_graph->node_indices_.insert( + {computation.get(), call_graph->nodes_.size()}); // All computation should be unique, so the computation should not already // exist in the map. TF_RET_CHECK(it_added.second); - call_graph.nodes_.emplace_back(computation.get()); + call_graph->nodes_.emplace_back(computation.get()); // Add all callsites in this computation. 
for (const std::unique_ptr<HloInstruction>& instruction : computation->instructions()) { - call_graph.nodes_.back().AddCallSitesInInstruction(instruction.get()); + TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction( + instruction.get())); } } @@ -218,25 +239,27 @@ StatusOr<CallGraph> CallGraph::Build(const HloModule* module) { for (const std::unique_ptr<HloComputation>& computation : module->computations()) { TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node, - call_graph.GetNode(computation.get())); + call_graph->GetNode(computation.get())); for (const CallSite& callsite : caller_node->callsites()) { - // Add caller callsites. - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, - call_graph.GetNode(callsite.called_computation)); - callee_node->AddCallerCallSite(callsite); + for (auto* callee : callsite.called_computations()) { + // Add caller callsites. + TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, + call_graph->GetNode(callee)); + callee_node->AddCallerCallSite(callsite); + } } } - TF_RETURN_IF_ERROR(call_graph.SetCallContexts()); + TF_RETURN_IF_ERROR(call_graph->SetCallContexts()); - XLA_VLOG_LINES(1, call_graph.ToString()); + XLA_VLOG_LINES(1, call_graph->ToString()); return std::move(call_graph); } Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode* node, - std::unordered_set<const CallGraphNode*>* visited) const { + tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const { auto pair = visited->insert(node); if (!pair.second) { // Node was not inserted. Node has already been visited. @@ -253,7 +276,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - std::unordered_set<const CallGraphNode*> visited; + tensorflow::gtl::FlatSet<const CallGraphNode*> visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { |