author    Mark Heffernan <meheff@google.com>  2017-04-28 14:51:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-28 16:12:06 -0700
commit    765d97a4168429a730862c9898cc936b445a054c (patch)
tree      4f3eb5db989746d8faab6e7b9f474bf433cc3b33 /tensorflow/compiler/xla/service/call_graph.cc
parent    998baa0f1fa8aee4382be1a00e4ae9ee70a6b194 (diff)
[XLA] Make HLO ordering module-scoped.
Add comparison of the ordering of HLO instructions that live in different computations, using the call graph. Previously, instructions in different computations were considered unordered. Ordering these instructions improves buffer liveness analysis and may enable better buffer sharing between values in different computations.

Change: 154592912
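The idea, in a hedged sketch: with a call graph in hand, two instructions in different computations can be compared by lifting each one to a callsite in a common ancestor computation and comparing there. The helper names below (NearestAncestorsInSameComputation, OrderedWithinComputation) are illustrative only, not APIs introduced by this commit:

  // Sketch only: decide whether `a` executes before `b` when they live in
  // different computations, by walking each up the call graph until both
  // land in a single computation.
  bool ExecutesBeforeSketch(const CallGraph& call_graph,
                            const HloInstruction* a, const HloInstruction* b) {
    // Hypothetical helper: returns the pair of instructions in the nearest
    // common ancestor computation whose execution encloses a and b.
    std::pair<HloInstruction*, HloInstruction*> ancestors =
        call_graph.NearestAncestorsInSameComputation(a, b);  // assumed API
    // Within one computation an ordinary per-computation order applies.
    return OrderedWithinComputation(ancestors.first, ancestors.second);
  }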
Diffstat (limited to 'tensorflow/compiler/xla/service/call_graph.cc')
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc  87
1 file changed, 39 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 57d69f5b71..fa7b2a3095 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -98,12 +98,12 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
}
}
-Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
- TF_RET_CHECK(instruction->parent() == computation());
+void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
+ CHECK_EQ(instruction->parent(), computation());
const CallContext context = GetInstructionCallContext(instruction);
if (!instruction->called_computations().empty()) {
- TF_RET_CHECK(context == CallContext::kSequential ||
- context == CallContext::kParallel);
+ CHECK(context == CallContext::kSequential ||
+ context == CallContext::kParallel);
callsite_instructions_.insert({instruction, callsites_.size()});
callsites_.push_back(
CallSite(instruction, instruction->called_computations(), context));
@@ -116,22 +116,21 @@ Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
}
}
}
- return Status::OK();
}
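The pattern this hunk establishes recurs through the rest of the diff: conditions that are internal invariants of the call graph (an instruction always belongs to its parent computation) rather than recoverable input errors move from TF_RET_CHECK, which returns an error Status to the caller, to CHECK/CHECK_EQ, which log and abort. A minimal sketch of the difference, using a made-up validation helper:

  // TF_RET_CHECK: failure becomes an INTERNAL-error Status, so the
  // enclosing function must return Status and callers must propagate it.
  Status ValidateParent(const HloInstruction* instruction,
                        const HloComputation* computation) {
    TF_RET_CHECK(instruction->parent() == computation);
    return Status::OK();
  }

  // CHECK_EQ: a violation crashes the process immediately, so the
  // enclosing function needs no Status plumbing at all.
  void ValidateParentOrDie(const HloInstruction* instruction,
                           const HloComputation* computation) {
    CHECK_EQ(instruction->parent(), computation);
  }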
CallGraph::CallGraph(const HloModule* module) : module_(module) {}
-StatusOr<const CallGraphNode*> CallGraph::GetNode(
+const CallGraphNode& CallGraph::GetNode(
const HloComputation* computation) const {
auto it = node_indices_.find(computation);
- TF_RET_CHECK(it != node_indices_.end());
- return &nodes_[it->second];
+ CHECK(it != node_indices_.end());
+ return nodes_[it->second];
}
-StatusOr<CallGraphNode*> CallGraph::GetNode(const HloComputation* computation) {
+CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
auto it = node_indices_.find(computation);
- TF_RET_CHECK(it != node_indices_.end());
- return &nodes_[it->second];
+ CHECK(it != node_indices_.end());
+ return nodes_[it->second];
}
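Returning a reference instead of StatusOr shortens every caller; the remaining hunks in this diff are largely that simplification applied mechanically. A before/after sketch of a typical call site, taken from the SetCallContexts changes below:

  // Before: two macros per lookup, and the enclosing function must
  // itself return Status so the macros have somewhere to bail to.
  TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
  node->set_context(CallContext::kSequential);

  // After: a plain reference; a missing computation is now a CHECK
  // failure inside GetNode rather than a recoverable error.
  CallGraphNode& node = GetNode(computation.get());
  node.set_context(CallContext::kSequential);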
namespace {
@@ -154,17 +153,17 @@ CallContext UnionContexts(CallContext a, CallContext b) {
} // namespace
-Status CallGraph::SetCallContexts() {
+void CallGraph::SetCallContexts() {
std::queue<CallGraphNode*> worklist;
// Initialize worklist with all roots of the call graph (computations without
// callers).
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
- if (node->callers().empty()) {
- node->set_context(CallContext::kSequential);
- worklist.push(node);
+ CallGraphNode& node = GetNode(computation.get());
+ if (node.callers().empty()) {
+ node.set_context(CallContext::kSequential);
+ worklist.push(&node);
}
}
@@ -174,7 +173,7 @@ Status CallGraph::SetCallContexts() {
for (const CallSite& callsite : node->callsites()) {
for (const HloComputation* callee : callsite.called_computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee));
+ CallGraphNode& callee_node = GetNode(callee);
// Update context of callee computation based on the callsite and its
// current context.
@@ -182,16 +181,16 @@ Status CallGraph::SetCallContexts() {
if (callsite.context() == CallContext::kParallel) {
context_to_add = CallContext::kParallel;
} else {
- TF_RET_CHECK(callsite.context() == CallContext::kSequential);
+ CHECK_EQ(callsite.context(), CallContext::kSequential);
context_to_add = node->context();
}
CallContext new_context =
- UnionContexts(context_to_add, callee_node->context());
+ UnionContexts(context_to_add, callee_node.context());
- if (new_context != callee_node->context()) {
+ if (new_context != callee_node.context()) {
// Context of computation has been changed so add node to worklist.
- callee_node->set_context(new_context);
- worklist.push(callee_node);
+ callee_node.set_context(new_context);
+ worklist.push(&callee_node);
}
}
}
@@ -200,14 +199,12 @@ Status CallGraph::SetCallContexts() {
// No node should have a kNone calling context.
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
- TF_RET_CHECK(node->context() != CallContext::kNone);
+ CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone);
}
- return Status::OK();
}
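SetCallContexts is a worklist fixed point: a node re-enters the worklist only when its context strictly grows, so the loop terminates once every computation's context is stable. The growth is defined by UnionContexts, whose definition sits just above this hunk (its name appears in the hunk header). A sketch consistent with how SetCallContexts uses it, with kNone as bottom and kBoth as top of the lattice:

  // Join on the small context lattice. kSequential and kParallel are
  // incomparable, so a mix of callers produces kBoth.
  CallContext UnionContexts(CallContext a, CallContext b) {
    if (a == CallContext::kNone) return b;
    if (b == CallContext::kNone) return a;
    if (a == b) return a;
    return CallContext::kBoth;  // one sequential and one parallel caller
  }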
/* static */
-StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) {
+std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
// Constructor for CallGraph is private so MakeUnique can't be used.
auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
@@ -221,54 +218,49 @@ StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) {
{computation.get(), call_graph->nodes_.size()});
// All computations should be unique, so the computation should not already
// exist in the map.
- TF_RET_CHECK(it_added.second);
+ CHECK(it_added.second);
call_graph->nodes_.emplace_back(computation.get());
// Add all callsites in this computation.
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
- TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction(
- instruction.get()));
+ call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get());
}
}
// Add caller callsites to each node.
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node,
- call_graph->GetNode(computation.get()));
- for (const CallSite& callsite : caller_node->callsites()) {
+ for (const CallSite& callsite :
+ call_graph->GetNode(computation.get()).callsites()) {
for (auto* callee : callsite.called_computations()) {
// Add caller callsites.
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
- call_graph->GetNode(callee));
- callee_node->AddCallerCallSite(callsite);
+ call_graph->GetNode(callee).AddCallerCallSite(callsite);
}
}
}
- TF_RETURN_IF_ERROR(call_graph->SetCallContexts());
-
+ call_graph->SetCallContexts();
XLA_VLOG_LINES(1, call_graph->ToString());
- return std::move(call_graph);
+ return call_graph;
}
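With Build infallible, constructing and querying the graph no longer threads Status through every layer. A short usage sketch, assuming `module` is a valid HloModule* with an entry computation:

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  const CallGraphNode& entry_node =
      call_graph->GetNode(module->entry_computation());
  // The entry computation has no callers, so SetCallContexts seeds it
  // with kSequential and nothing can later widen it.
  CHECK_EQ(entry_node.context(), CallContext::kSequential);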
Status CallGraph::VisitNodesInternal(
- const VisitorFunction& visitor_func, const CallGraphNode* node,
+ const VisitorFunction& visitor_func, const CallGraphNode& node,
tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
- auto pair = visited->insert(node);
+ auto pair = visited->insert(&node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
return Status::OK();
}
- for (const HloComputation* computation : node->callees()) {
- TF_ASSIGN_OR_RETURN(const CallGraphNode* callee_node, GetNode(computation));
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, callee_node, visited));
+ for (const HloComputation* computation : node.callees()) {
+ TF_RETURN_IF_ERROR(
+ VisitNodesInternal(visitor_func, GetNode(computation), visited));
}
- return visitor_func(*node);
+ return visitor_func(node);
}
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
@@ -278,14 +270,13 @@ Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {
if (node.callers().empty()) {
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, &node, &visited));
+ TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
}
}
} else {
// Traverse only from the entry computation.
- TF_ASSIGN_OR_RETURN(const CallGraphNode* entry_node,
- GetNode(module_->entry_computation()));
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, entry_node, &visited));
+ TF_RETURN_IF_ERROR(VisitNodesInternal(
+ visitor_func, GetNode(module_->entry_computation()), &visited));
}
return Status::OK();
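Note that VisitNodes and VisitNodesInternal still return Status: the user-supplied visitor can fail, even though the node lookups inside the traversal are now infallible. A sketch of a visitor that counts nodes reachable from the entry computation (the visit_unreachable flag name is taken from the branch in the function body above; treat the exact signature as assumed):

  int64 node_count = 0;
  TF_CHECK_OK(call_graph->VisitNodes(
      [&node_count](const CallGraphNode& node) {
        ++node_count;  // returning a non-OK Status here aborts the traversal
        return Status::OK();
      },
      /*visit_unreachable=*/false));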