author     Mark Heffernan <meheff@google.com>                2017-03-17 11:27:00 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-17 12:45:35 -0700
commit     d687f9bb40d978912ef8ddd531fa39e032de4c39 (patch)
tree       e14e1d16f52439e07752501694b388317c54aafd /tensorflow/compiler/xla/service/call_graph.cc
parent     c76bc807c3089069d520edd6ab7f3f07a0c76f53 (diff)
[XLA] Add mapping from HloInstruction to CallGraphNode.
Make the mapping from calling instruction to CallSite one-to-one by extending CallSite to handle more than one called computation. This enables instructions like kWhile, which call two computations, to be represented as a single CallSite. Also add a mapping from instruction to CallSite in CallGraphNode to enable fast call site lookup. A few other opportunistic improvements are included:

* Change the CallGraph::Build factory to return a std::unique_ptr. This enables, for example, more convenient use of CallGraph as a data member of a class.
* Change a few uses of unordered_set/map to FlatSet/Map.

Change: 150469958
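A minimal usage sketch of the reshaped API after this change (not part of the commit; the names `module` and `while_instr` are assumed here to be an HloModule* and a kWhile HloInstruction* already in scope, inside a function returning Status):

// Sketch only: `module` and `while_instr` are assumed, not defined by this
// commit. Build() now returns StatusOr<std::unique_ptr<CallGraph>>, so the
// graph can be held as a class data member or moved around freely.
TF_ASSIGN_OR_RETURN(std::unique_ptr<CallGraph> call_graph,
                    CallGraph::Build(module));
TF_ASSIGN_OR_RETURN(const CallGraphNode* node,
                    call_graph->GetNode(while_instr->parent()));
// One CallSite per calling instruction: a kWhile contributes a single
// CallSite whose called_computations() holds both condition and body, and
// GetCallSite() is a map lookup rather than a scan over all callsites.
const CallSite* callsite = node->GetCallSite(while_instr);
TF_RET_CHECK(callsite != nullptr);
TF_RET_CHECK(callsite->called_computations().size() == 2);
TF_RET_CHECK(callsite->context() == CallContext::kSequential);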
Diffstat (limited to 'tensorflow/compiler/xla/service/call_graph.cc')
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc  155
1 file changed, 89 insertions(+), 66 deletions(-)
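For orientation before the diff: since a CallSite can now name several called computations, consumers iterate called_computations() per CallSite, exactly as the reworked SetCallContexts loop below does. A hypothetical traversal, assuming `node` is a const CallGraphNode* obtained from GetNode():

// Assumed sketch: both of a kWhile's computations (condition and body) are
// reached from the same CallSite.
for (const CallSite& callsite : node->callsites()) {
  for (const HloComputation* callee : callsite.called_computations()) {
    VLOG(2) << callsite.instruction()->name() << " calls " << callee->name();
  }
}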
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index fdd963a84a..f75aa9bd11 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,10 +17,13 @@ limitations under the License.
#include <queue>
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
@@ -49,64 +52,79 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) {
}
string CallSite::ToString() const {
- return StrCat(instruction->name(), " calls ", called_computation->name(),
- ", ", CallContextToString(context));
+ return StrCat(instruction()->name(), " calls in context ",
+ CallContextToString(context()), ": ",
+ tensorflow::str_util::Join(
+ called_computations(), ", ",
+ [](string* out, const HloComputation* computation) {
+ out->append(computation->name());
+ }));
}
CallGraphNode::CallGraphNode(HloComputation* computation)
: computation_(computation) {}
-void CallGraphNode::AddCallSite(const CallSite& callsite) {
- callsites_.push_back(callsite);
- HloComputation* callee = callsite.called_computation;
- if (callee_set_.count(callee) == 0) {
- callees_.push_back(callee);
- callee_set_.insert(callee);
+const CallSite* CallGraphNode::GetCallSite(
+ const HloInstruction* instruction) const {
+ auto it = callsite_instructions_.find(instruction);
+ if (it == callsite_instructions_.end()) {
+ return nullptr;
}
+ return &callsites_[it->second];
}
void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
caller_callsites_.push_back(caller_callsite);
- HloComputation* caller = caller_callsite.instruction->parent();
- if (caller_set_.count(caller) == 0) {
+ HloComputation* caller = caller_callsite.instruction()->parent();
+ if (!ContainsKey(caller_set_, caller)) {
callers_.push_back(caller);
caller_set_.insert(caller);
}
}
-void CallGraphNode::AddCallSitesInInstruction(HloInstruction* instruction) {
+namespace {
+
+CallContext GetInstructionCallContext(const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kCall:
- AddCallSite(
- {instruction, instruction->to_apply(), CallContext::kSequential});
- break;
+ case HloOpcode::kWhile:
+ return CallContext::kSequential;
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
- AddCallSite(
- {instruction, instruction->to_apply(), CallContext::kParallel});
- break;
case HloOpcode::kSelectAndScatter:
- AddCallSite({instruction, instruction->select(), CallContext::kParallel});
- AddCallSite(
- {instruction, instruction->scatter(), CallContext::kParallel});
- break;
- case HloOpcode::kWhile:
- AddCallSite({instruction, instruction->while_condition(),
- CallContext::kSequential});
- AddCallSite(
- {instruction, instruction->while_body(), CallContext::kSequential});
- break;
case HloOpcode::kFusion:
- for (const auto& fused_instruction : instruction->fused_instructions()) {
- AddCallSitesInInstruction(fused_instruction.get());
- }
- break;
+ return CallContext::kParallel;
default:
- break;
+ return CallContext::kNone;
}
}
+} // namespace
+
+Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
+ TF_RET_CHECK(instruction->parent() == computation());
+ CallContext context = GetInstructionCallContext(instruction);
+ if (instruction->called_computations().empty()) {
+ TF_RET_CHECK(context == CallContext::kNone);
+ } else {
+ TF_RET_CHECK(context == CallContext::kSequential ||
+ context == CallContext::kParallel);
+ callsite_instructions_.insert({instruction, callsites_.size()});
+ callsites_.push_back(
+ CallSite(instruction, instruction->called_computations(), context));
+ // Update callee computations to include any new computations called by this
+ // instruction.
+ for (auto* callee : callsites_.back().called_computations()) {
+ if (!ContainsKey(callee_set_, callee)) {
+ callees_.push_back(callee);
+ callee_set_.insert(callee);
+ }
+ }
+ }
+ return Status::OK();
+}
+
CallGraph::CallGraph(const HloModule* module) : module_(module) {}
StatusOr<const CallGraphNode*> CallGraph::GetNode(
@@ -161,25 +179,26 @@ Status CallGraph::SetCallContexts() {
worklist.pop();
for (const CallSite& callsite : node->callsites()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
- GetNode(callsite.called_computation));
-
- // Update context of callee computation based on the callsite and its
- // current context.
- CallContext context_to_add;
- if (callsite.context == CallContext::kParallel) {
- context_to_add = CallContext::kParallel;
- } else {
- TF_RET_CHECK(callsite.context == CallContext::kSequential);
- context_to_add = node->context();
- }
- CallContext new_context =
- UnionContexts(context_to_add, callee_node->context());
-
- if (new_context != callee_node->context()) {
- // Context of computation has been changed so add node to worklist.
- callee_node->set_context(new_context);
- worklist.push(callee_node);
+ for (const HloComputation* callee : callsite.called_computations()) {
+ TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee));
+
+ // Update context of callee computation based on the callsite and its
+ // current context.
+ CallContext context_to_add;
+ if (callsite.context() == CallContext::kParallel) {
+ context_to_add = CallContext::kParallel;
+ } else {
+ TF_RET_CHECK(callsite.context() == CallContext::kSequential);
+ context_to_add = node->context();
+ }
+ CallContext new_context =
+ UnionContexts(context_to_add, callee_node->context());
+
+ if (new_context != callee_node->context()) {
+ // Context of computation has been changed so add node to worklist.
+ callee_node->set_context(new_context);
+ worklist.push(callee_node);
+ }
}
}
}
@@ -194,23 +213,25 @@ Status CallGraph::SetCallContexts() {
}
/* static */
-StatusOr<CallGraph> CallGraph::Build(const HloModule* module) {
- CallGraph call_graph(module);
+StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) {
+ // Constructor for CallGraph is private so MakeUnique can't be used.
+ auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
// Construct nodes of the call graph and populate the callsites.
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
- auto it_added = call_graph.node_indices_.insert(
- {computation.get(), call_graph.nodes_.size()});
+ auto it_added = call_graph->node_indices_.insert(
+ {computation.get(), call_graph->nodes_.size()});
// All computations should be unique, so the computation should not already
// exist in the map.
TF_RET_CHECK(it_added.second);
- call_graph.nodes_.emplace_back(computation.get());
+ call_graph->nodes_.emplace_back(computation.get());
// Add all callsites in this computation.
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
- call_graph.nodes_.back().AddCallSitesInInstruction(instruction.get());
+ TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction(
+ instruction.get()));
}
}
@@ -218,25 +239,27 @@ StatusOr<CallGraph> CallGraph::Build(const HloModule* module) {
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node,
- call_graph.GetNode(computation.get()));
+ call_graph->GetNode(computation.get()));
for (const CallSite& callsite : caller_node->callsites()) {
- // Add caller callsites.
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
- call_graph.GetNode(callsite.called_computation));
- callee_node->AddCallerCallSite(callsite);
+ for (auto* callee : callsite.called_computations()) {
+ // Add caller callsites.
+ TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
+ call_graph->GetNode(callee));
+ callee_node->AddCallerCallSite(callsite);
+ }
}
}
- TF_RETURN_IF_ERROR(call_graph.SetCallContexts());
+ TF_RETURN_IF_ERROR(call_graph->SetCallContexts());
- XLA_VLOG_LINES(1, call_graph.ToString());
+ XLA_VLOG_LINES(1, call_graph->ToString());
return std::move(call_graph);
}
Status CallGraph::VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode* node,
- std::unordered_set<const CallGraphNode*>* visited) const {
+ tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
auto pair = visited->insert(node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
@@ -253,7 +276,7 @@ Status CallGraph::VisitNodesInternal(
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
bool visit_unreachable_nodes) const {
- std::unordered_set<const CallGraphNode*> visited;
+ tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
if (visit_unreachable_nodes) {
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {