aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-04-28 14:51:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-28 16:12:06 -0700
commit765d97a4168429a730862c9898cc936b445a054c (patch)
tree4f3eb5db989746d8faab6e7b9f474bf433cc3b33 /tensorflow/compiler/xla
parent998baa0f1fa8aee4382be1a00e4ae9ee70a6b194 (diff)
[XLA] Make HLO ordering module-scoped.
Add comparison of ordering of HLO instructions which are in different computations using the call graph. Previously, instructions in different computations were considered unordered. Ordering these instructions improves buffer liveness analysis and may enable better buffer sharing between values in different computations. Change: 154592912
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc87
-rw-r--r--tensorflow/compiler/xla/service/call_graph.h14
-rw-r--r--tensorflow/compiler/xla/service/call_graph_test.cc159
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph.cc3
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc103
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc77
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc12
11 files changed, 333 insertions, 215 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e17205be23..2137888726 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -195,7 +195,6 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
@@ -624,6 +623,7 @@ cc_library(
"buffer_liveness.h",
],
deps = [
+ ":call_graph",
":hlo",
":hlo_ordering",
":liveness_util",
@@ -747,6 +747,7 @@ cc_library(
"hlo_ordering.h",
],
deps = [
+ ":call_graph",
":heap_simulator",
":hlo",
":logical_buffer",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 0d6e89c5c6..ac1d769010 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -856,8 +856,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
EXPECT_FALSE(map_root_alloc.maybe_live_out());
EXPECT_TRUE(map_root_alloc.is_thread_local());
- // Allocations for the call computation should not be thread-local and not
- // live-out.
+ // Allocations for the call computation should not be thread-local.
auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter());
EXPECT_FALSE(call_param_alloc.maybe_live_out());
@@ -865,7 +864,6 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
- EXPECT_FALSE(call_root_alloc.maybe_live_out());
EXPECT_FALSE(call_root_alloc.is_thread_local());
// Entry computation allocations can be marked liveout and
@@ -1445,8 +1443,7 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
FlattenCallGraph flatten;
TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get()));
EXPECT_TRUE(result);
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(module.get()));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
}
RunCopyInsertion(module.get());
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 57d69f5b71..fa7b2a3095 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -98,12 +98,12 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
}
}
-Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
- TF_RET_CHECK(instruction->parent() == computation());
+void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
+ CHECK_EQ(instruction->parent(), computation());
const CallContext context = GetInstructionCallContext(instruction);
if (!instruction->called_computations().empty()) {
- TF_RET_CHECK(context == CallContext::kSequential ||
- context == CallContext::kParallel);
+ CHECK(context == CallContext::kSequential ||
+ context == CallContext::kParallel);
callsite_instructions_.insert({instruction, callsites_.size()});
callsites_.push_back(
CallSite(instruction, instruction->called_computations(), context));
@@ -116,22 +116,21 @@ Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
}
}
}
- return Status::OK();
}
CallGraph::CallGraph(const HloModule* module) : module_(module) {}
-StatusOr<const CallGraphNode*> CallGraph::GetNode(
+const CallGraphNode& CallGraph::GetNode(
const HloComputation* computation) const {
auto it = node_indices_.find(computation);
- TF_RET_CHECK(it != node_indices_.end());
- return &nodes_[it->second];
+ CHECK(it != node_indices_.end());
+ return nodes_[it->second];
}
-StatusOr<CallGraphNode*> CallGraph::GetNode(const HloComputation* computation) {
+CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
auto it = node_indices_.find(computation);
- TF_RET_CHECK(it != node_indices_.end());
- return &nodes_[it->second];
+ CHECK(it != node_indices_.end());
+ return nodes_[it->second];
}
namespace {
@@ -154,17 +153,17 @@ CallContext UnionContexts(CallContext a, CallContext b) {
} // namespace
-Status CallGraph::SetCallContexts() {
+void CallGraph::SetCallContexts() {
std::queue<CallGraphNode*> worklist;
// Initialize worklist with all roots of the call graph (computations without
// callers).
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
- if (node->callers().empty()) {
- node->set_context(CallContext::kSequential);
- worklist.push(node);
+ CallGraphNode& node = GetNode(computation.get());
+ if (node.callers().empty()) {
+ node.set_context(CallContext::kSequential);
+ worklist.push(&node);
}
}
@@ -174,7 +173,7 @@ Status CallGraph::SetCallContexts() {
for (const CallSite& callsite : node->callsites()) {
for (const HloComputation* callee : callsite.called_computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee));
+ CallGraphNode& callee_node = GetNode(callee);
// Update context of callee computation based on the callsite and its
// current context.
@@ -182,16 +181,16 @@ Status CallGraph::SetCallContexts() {
if (callsite.context() == CallContext::kParallel) {
context_to_add = CallContext::kParallel;
} else {
- TF_RET_CHECK(callsite.context() == CallContext::kSequential);
+ CHECK_EQ(callsite.context(), CallContext::kSequential);
context_to_add = node->context();
}
CallContext new_context =
- UnionContexts(context_to_add, callee_node->context());
+ UnionContexts(context_to_add, callee_node.context());
- if (new_context != callee_node->context()) {
+ if (new_context != callee_node.context()) {
// Context of computation has been changed so add node to worklist.
- callee_node->set_context(new_context);
- worklist.push(callee_node);
+ callee_node.set_context(new_context);
+ worklist.push(&callee_node);
}
}
}
@@ -200,14 +199,12 @@ Status CallGraph::SetCallContexts() {
// No node should have a kNone calling context.
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
- TF_RET_CHECK(node->context() != CallContext::kNone);
+ CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone);
}
- return Status::OK();
}
/* static */
-StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) {
+std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
// Constructor for CallGraph is private so MakeUnique can't be used.
auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
@@ -221,54 +218,49 @@ StatusOr<std::unique_ptr<CallGraph>> CallGraph::Build(const HloModule* module) {
{computation.get(), call_graph->nodes_.size()});
// All computations should be unique, so the computation should not already
// exist in the map.
- TF_RET_CHECK(it_added.second);
+ CHECK(it_added.second);
call_graph->nodes_.emplace_back(computation.get());
// Add all callsites in this computation.
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
- TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction(
- instruction.get()));
+ call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get());
}
}
// Add caller callsites to each node.
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
- TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node,
- call_graph->GetNode(computation.get()));
- for (const CallSite& callsite : caller_node->callsites()) {
+ for (const CallSite& callsite :
+ call_graph->GetNode(computation.get()).callsites()) {
for (auto* callee : callsite.called_computations()) {
// Add caller callsites.
- TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
- call_graph->GetNode(callee));
- callee_node->AddCallerCallSite(callsite);
+ call_graph->GetNode(callee).AddCallerCallSite(callsite);
}
}
}
- TF_RETURN_IF_ERROR(call_graph->SetCallContexts());
-
+ call_graph->SetCallContexts();
XLA_VLOG_LINES(1, call_graph->ToString());
- return std::move(call_graph);
+ return call_graph;
}
Status CallGraph::VisitNodesInternal(
- const VisitorFunction& visitor_func, const CallGraphNode* node,
+ const VisitorFunction& visitor_func, const CallGraphNode& node,
tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
- auto pair = visited->insert(node);
+ auto pair = visited->insert(&node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
return Status::OK();
}
- for (const HloComputation* computation : node->callees()) {
- TF_ASSIGN_OR_RETURN(const CallGraphNode* callee_node, GetNode(computation));
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, callee_node, visited));
+ for (const HloComputation* computation : node.callees()) {
+ TF_RETURN_IF_ERROR(
+ VisitNodesInternal(visitor_func, GetNode(computation), visited));
}
- return visitor_func(*node);
+ return visitor_func(node);
}
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
@@ -278,14 +270,13 @@ Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {
if (node.callers().empty()) {
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, &node, &visited));
+ TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
}
}
} else {
// Traverse only from the entry computation.
- TF_ASSIGN_OR_RETURN(const CallGraphNode* entry_node,
- GetNode(module_->entry_computation()));
- TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, entry_node, &visited));
+ TF_RETURN_IF_ERROR(VisitNodesInternal(
+ visitor_func, GetNode(module_->entry_computation()), &visited));
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 62d12f8f91..7f9990f06d 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -138,7 +137,7 @@ class CallGraphNode {
// If instruction calls any computations adds a call site for this instruction
// to the call graph node. If the instruction calls no computations then no
// call site is added.
- Status AddCallSiteForInstruction(HloInstruction* instruction);
+ void AddCallSiteForInstruction(HloInstruction* instruction);
// Computation represented by this call graph node.
HloComputation* computation_;
@@ -174,12 +173,11 @@ class CallGraph {
using VisitorFunction = std::function<Status(const CallGraphNode&)>;
// Builds and returns a call graph for the given HLO module.
- static StatusOr<std::unique_ptr<CallGraph>> Build(const HloModule* module);
+ static std::unique_ptr<CallGraph> Build(const HloModule* module);
// Returns the node associated with the given computation.
- StatusOr<const CallGraphNode*> GetNode(
- const HloComputation* computation) const;
- StatusOr<CallGraphNode*> GetNode(const HloComputation* computation);
+ const CallGraphNode& GetNode(const HloComputation* computation) const;
+ CallGraphNode& GetNode(const HloComputation* computation);
// Returns the vector of all nodes in the call graph.
const std::vector<CallGraphNode>& nodes() const { return nodes_; }
@@ -197,14 +195,14 @@ class CallGraph {
CallGraph(const HloModule* module);
// Sets the call contexts for every node in the graph.
- Status SetCallContexts();
+ void SetCallContexts();
// Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS
// post order (callee before caller) calling visitor_func on each node. Adds
// nodes to 'visited' as each node is visited. Skips nodes already in
// 'visited'.
Status VisitNodesInternal(
- const VisitorFunction& visitor_func, const CallGraphNode* node,
+ const VisitorFunction& visitor_func, const CallGraphNode& node,
tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
// The HLO module represented by this call graph.
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index f71a5d01af..ab0ea47d02 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -95,17 +95,15 @@ TEST_F(CallGraphTest, SingletonComputation) {
HloModule module(TestName());
HloComputation* computation =
module.AddEntryComputation(MakeScalarComputation());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(1, call_graph->nodes().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* node,
- call_graph->GetNode(computation));
- EXPECT_EQ(computation, node->computation());
- EXPECT_TRUE(node->callsites().empty());
- EXPECT_TRUE(node->callees().empty());
- EXPECT_TRUE(node->caller_callsites().empty());
- EXPECT_TRUE(node->callers().empty());
- EXPECT_EQ(CallContext::kSequential, node->context());
+ const CallGraphNode& node = call_graph->GetNode(computation);
+ EXPECT_EQ(computation, node.computation());
+ EXPECT_TRUE(node.callsites().empty());
+ EXPECT_TRUE(node.callees().empty());
+ EXPECT_TRUE(node.caller_callsites().empty());
+ EXPECT_TRUE(node.callers().empty());
+ EXPECT_EQ(CallContext::kSequential, node.context());
}
TEST_F(CallGraphTest, UnreachableComputation) {
@@ -117,19 +115,17 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module.AddEmbeddedComputation(MakeScalarComputation());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(2, call_graph->nodes().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
- call_graph->GetNode(entry_computation));
- EXPECT_EQ(entry_computation, entry_node->computation());
- EXPECT_EQ(CallContext::kSequential, entry_node->context());
+ const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+ EXPECT_EQ(entry_computation, entry_node.computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node.context());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* unreachable_node,
- call_graph->GetNode(unreachable_computation));
- EXPECT_EQ(unreachable_computation, unreachable_node->computation());
- EXPECT_EQ(CallContext::kSequential, unreachable_node->context());
+ const CallGraphNode& unreachable_node =
+ call_graph->GetNode(unreachable_computation);
+ EXPECT_EQ(unreachable_computation, unreachable_node.computation());
+ EXPECT_EQ(CallContext::kSequential, unreachable_node.context());
}
TEST_F(CallGraphTest, ParallelComputation) {
@@ -141,27 +137,24 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module.AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(2, call_graph->nodes().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
- call_graph->GetNode(entry_computation));
- EXPECT_EQ(entry_computation, entry_node->computation());
- EXPECT_EQ(CallContext::kSequential, entry_node->context());
- EXPECT_EQ(5, entry_node->callsites().size());
- EXPECT_EQ(1, entry_node->callees().size());
- EXPECT_TRUE(entry_node->caller_callsites().empty());
- EXPECT_TRUE(entry_node->callers().empty());
-
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* map_node,
- call_graph->GetNode(map_computation));
- EXPECT_EQ(map_computation, map_node->computation());
- EXPECT_EQ(CallContext::kParallel, map_node->context());
- EXPECT_TRUE(map_node->callsites().empty());
- EXPECT_TRUE(map_node->callees().empty());
- EXPECT_EQ(5, map_node->caller_callsites().size());
- EXPECT_EQ(1, map_node->callers().size());
+ const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+ EXPECT_EQ(entry_computation, entry_node.computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node.context());
+ EXPECT_EQ(5, entry_node.callsites().size());
+ EXPECT_EQ(1, entry_node.callees().size());
+ EXPECT_TRUE(entry_node.caller_callsites().empty());
+ EXPECT_TRUE(entry_node.callers().empty());
+
+ const CallGraphNode& map_node = call_graph->GetNode(map_computation);
+ EXPECT_EQ(map_computation, map_node.computation());
+ EXPECT_EQ(CallContext::kParallel, map_node.context());
+ EXPECT_TRUE(map_node.callsites().empty());
+ EXPECT_TRUE(map_node.callees().empty());
+ EXPECT_EQ(5, map_node.caller_callsites().size());
+ EXPECT_EQ(1, map_node.callers().size());
}
TEST_F(CallGraphTest, SequentialComputations) {
@@ -173,27 +166,24 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module.AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(2, call_graph->nodes().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
- call_graph->GetNode(entry_computation));
- EXPECT_EQ(entry_computation, entry_node->computation());
- EXPECT_EQ(CallContext::kSequential, entry_node->context());
- EXPECT_EQ(3, entry_node->callsites().size());
- EXPECT_EQ(1, entry_node->callees().size());
- EXPECT_TRUE(entry_node->caller_callsites().empty());
- EXPECT_TRUE(entry_node->callers().empty());
-
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* called_node,
- call_graph->GetNode(called_computation));
- EXPECT_EQ(called_computation, called_node->computation());
- EXPECT_EQ(CallContext::kSequential, called_node->context());
- EXPECT_TRUE(called_node->callsites().empty());
- EXPECT_TRUE(called_node->callees().empty());
- EXPECT_EQ(3, called_node->caller_callsites().size());
- EXPECT_EQ(1, called_node->callers().size());
+ const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+ EXPECT_EQ(entry_computation, entry_node.computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node.context());
+ EXPECT_EQ(3, entry_node.callsites().size());
+ EXPECT_EQ(1, entry_node.callees().size());
+ EXPECT_TRUE(entry_node.caller_callsites().empty());
+ EXPECT_TRUE(entry_node.callers().empty());
+
+ const CallGraphNode& called_node = call_graph->GetNode(called_computation);
+ EXPECT_EQ(called_computation, called_node.computation());
+ EXPECT_EQ(CallContext::kSequential, called_node.context());
+ EXPECT_TRUE(called_node.callsites().empty());
+ EXPECT_TRUE(called_node.callees().empty());
+ EXPECT_EQ(3, called_node.caller_callsites().size());
+ EXPECT_EQ(1, called_node.callers().size());
}
TEST_F(CallGraphTest, ContextBothComputations) {
@@ -213,32 +203,29 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module.AddEntryComputation(builder.Build());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(2, call_graph->nodes().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
- call_graph->GetNode(entry_computation));
- EXPECT_EQ(entry_computation, entry_node->computation());
- EXPECT_EQ(2, entry_node->callsites().size());
+ const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+ EXPECT_EQ(entry_computation, entry_node.computation());
+ EXPECT_EQ(2, entry_node.callsites().size());
- const CallSite& call_callsite = entry_node->callsites()[0];
+ const CallSite& call_callsite = entry_node.callsites()[0];
EXPECT_EQ(call, call_callsite.instruction());
EXPECT_THAT(call_callsite.called_computations(),
UnorderedElementsAre(subcomputation));
EXPECT_EQ(CallContext::kSequential, call_callsite.context());
- EXPECT_EQ(entry_node->GetCallSite(call), &call_callsite);
+ EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);
- const CallSite& map_callsite = entry_node->callsites()[1];
+ const CallSite& map_callsite = entry_node.callsites()[1];
EXPECT_EQ(map, map_callsite.instruction());
EXPECT_THAT(map_callsite.called_computations(),
UnorderedElementsAre(subcomputation));
EXPECT_EQ(CallContext::kParallel, map_callsite.context());
- EXPECT_EQ(entry_node->GetCallSite(map), &map_callsite);
+ EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* sub_node,
- call_graph->GetNode(subcomputation));
- EXPECT_EQ(CallContext::kBoth, sub_node->context());
+ const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
+ EXPECT_EQ(CallContext::kBoth, sub_node.context());
}
TEST_F(CallGraphTest, ComplexGraph) {
@@ -284,27 +271,24 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module.AddEntryComputation(builder.Build());
}
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(5, call_graph->nodes().size());
// Entry computation has one while instruction calling two computations
// (cond_computation and a_computation).
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
- call_graph->GetNode(entry_computation));
- ASSERT_EQ(1, entry_node->callsites().size());
+ const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+ ASSERT_EQ(1, entry_node.callsites().size());
const std::vector<HloComputation*>& called_computations =
- entry_node->callsites()[0].called_computations();
+ entry_node.callsites()[0].called_computations();
EXPECT_THAT(called_computations,
UnorderedElementsAre(cond_computation, a_computation));
- EXPECT_EQ(CallContext::kSequential, entry_node->context());
+ EXPECT_EQ(CallContext::kSequential, entry_node.context());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node,
- call_graph->GetNode(c_computation));
- EXPECT_TRUE(c_node->callsites().empty());
- EXPECT_THAT(c_node->callers(),
+ const CallGraphNode& c_node = call_graph->GetNode(c_computation);
+ EXPECT_TRUE(c_node.callsites().empty());
+ EXPECT_THAT(c_node.callers(),
UnorderedElementsAre(a_computation, b_computation));
- EXPECT_EQ(CallContext::kBoth, c_node->context());
+ EXPECT_EQ(CallContext::kBoth, c_node.context());
// Visit the graph and verify nodes were visited in callee-before-caller
// order.
@@ -337,8 +321,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
HloModule module(TestName());
HloComputation* computation =
module.AddEntryComputation(MakeScalarComputation());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -355,8 +338,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module.AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module.AddEmbeddedComputation(MakeScalarComputation());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
// Test visitation of only reachable nodes.
{
@@ -390,8 +372,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
HloModule module(TestName());
module.AddEntryComputation(MakeScalarComputation());
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc
index 3c41fe870f..297a4f7599 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc
@@ -102,8 +102,7 @@ Status FlattenNode(const CallGraphNode& node) {
StatusOr<bool> FlattenCallGraph::Run(HloModule* module) {
XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode));
XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 6c4a48bbe8..4e03a96fb3 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -141,11 +141,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
{
TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
EXPECT_TRUE(result);
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> flat_call_graph,
- CallGraph::Build(&module));
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node,
- flat_call_graph->GetNode(c_computation));
- EXPECT_EQ(1, c_node->caller_callsites().size());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(&module);
+ const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
+ EXPECT_EQ(1, c_node.caller_callsites().size());
}
}
@@ -178,21 +176,17 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* cond_node,
- call_graph->GetNode(cond_computation));
- EXPECT_EQ(2, cond_node->caller_callsites().size());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
+ EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
EXPECT_TRUE(result);
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* cond_node,
- call_graph->GetNode(cond_computation));
- EXPECT_EQ(1, cond_node->caller_callsites().size());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
+ const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
+ EXPECT_EQ(1, cond_node.caller_callsites().size());
}
}
@@ -219,17 +213,14 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module));
EXPECT_TRUE(result);
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<CallGraph> call_graph,
- CallGraph::Build(&module));
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);
EXPECT_EQ(7, module.computations().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node,
- call_graph->GetNode(c_computation));
- EXPECT_EQ(1, c_node->caller_callsites().size());
+ const CallGraphNode& c_node = call_graph->GetNode(c_computation);
+ EXPECT_EQ(1, c_node.caller_callsites().size());
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* b_node,
- call_graph->GetNode(b_computation));
- EXPECT_EQ(1, b_node->caller_callsites().size());
+ const CallGraphNode& b_node = call_graph->GetNode(b_computation);
+ EXPECT_EQ(1, b_node.caller_callsites().size());
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index b3168ed40e..7476b72f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -34,15 +34,95 @@ limitations under the License.
namespace xla {
-PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
- : module_(module) {}
+namespace {
-bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a,
- const HloInstruction* b) const {
- // Instructions in different computations are unordered.
- if (a->parent() != b->parent()) {
+// Returns the nearest call graph ancestors of instructions 'a' and 'b' for
+// which the ancestors are in the same computation. An instruction is an call
+// graph ancestor of 'a' if the instruction calls the computation containing 'a'
+// either directly or transitively. Degeneratively an instruction is an ancestor
+// of itself. nullptr is returned if there is no common ancestor or if the
+// caller chain of 'a' or 'b' diverges (has multiple callers) before the nearest
+// common ancestor.
+//
+// Example:
+//
+// Entry computation:
+// %x = Call(A, {Constant(42.0)})
+// %y = Call(B, {%x})
+//
+// Computation A:
+// %a = Negate(Param())
+//
+// Computation B:
+// %b = Exp(Param());
+//
+// If called with %a and %b, this function would return (%x, %y). %x is an
+// ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same
+// computation.
+std::pair<const HloInstruction*, const HloInstruction*>
+GetNearestCallGraphAncestorsInSameComputation(const HloInstruction* a,
+ const HloInstruction* b,
+ const CallGraph& call_graph) {
+ // Lambda which returns the next instruction in the callee->caller chain in
+ // the call graph. This is the unique instruction which calls the computation
+ // containing 'instruction'. If more than one instruction calls the
+ // computation containing 'instruction' or no instructions call the
+ // computation then nullptr is returned.
+ auto next_caller =
+ [&call_graph](
+ const HloInstruction* instruction) -> const HloInstruction* {
+ const CallGraphNode& node = call_graph.GetNode(instruction->parent());
+ if (node.caller_callsites().size() != 1) {
+ return nullptr;
+ }
+ return node.caller_callsites()[0].instruction();
+ };
+
+ // Iterate through the callee->caller chains and find the earliest common
+ // element.
+ for (const HloInstruction* a_ancestor = a; a_ancestor != nullptr;
+ a_ancestor = next_caller(a_ancestor)) {
+ for (const HloInstruction* b_ancestor = b; b_ancestor != nullptr;
+ b_ancestor = next_caller(b_ancestor)) {
+ if (a_ancestor->parent() == b_ancestor->parent()) {
+ return {a_ancestor, b_ancestor};
+ }
+ }
+ }
+ return {nullptr, nullptr};
+}
+
+} // namespace
+
+bool HloOrdering::ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const {
+ // 'a' and 'b' may be in different computations. In this case, find the
+ // callgraph ancestor instructions which call (potentially transitively) the
+ // computations containing 'a' and 'b' and use these ancestor instructions to
+ // compare order.
+ const HloInstruction* a_ancestor;
+ const HloInstruction* b_ancestor;
+ std::tie(a_ancestor, b_ancestor) =
+ GetNearestCallGraphAncestorsInSameComputation(a, b, *call_graph_);
+
+ if (a_ancestor == nullptr) {
+ // Ancestors in a common computation could not be found so consider the
+ // instructions 'a' and 'b' to be unordered.
return false;
}
+ // a_ancestor and b_ancestor must be either both null or both non-null.
+ CHECK_NE(b_ancestor, nullptr);
+ CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
+ return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
+}
+
+PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
+ : HloOrdering(module) {}
+
+bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
+ const HloInstruction* a, const HloInstruction* b) const {
+ CHECK_EQ(a->parent(), b->parent());
+
// 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
return strict_predecessors_.at(b->parent())->IsReachable(b, a);
}
@@ -86,7 +166,7 @@ string DependencyHloOrdering::ToString() const {
SequentialHloOrdering::SequentialHloOrdering(
const HloModule* module, const HloModuleSequence& module_sequence)
- : module_(module), module_sequence_(module_sequence) {
+ : HloOrdering(module), module_sequence_(module_sequence) {
// Create a map from instruction to its order position.
for (auto computation_order : module_sequence_) {
const std::vector<const HloInstruction*>& order = computation_order.second;
@@ -97,12 +177,9 @@ SequentialHloOrdering::SequentialHloOrdering(
}
}
-bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a,
- const HloInstruction* b) const {
- // Instructions in different computations are unordered.
- if (a->parent() != b->parent()) {
- return false;
- }
+bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
+ const HloInstruction* a, const HloInstruction* b) const {
+ CHECK_EQ(a->parent(), b->parent());
// If either instruction is not in the order, then 'a' and 'b' are unordered.
if (order_position_.count(a) == 0 || order_position_.count(b) == 0) {
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index e964c4c51a..d2db18be00 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -36,13 +37,13 @@ namespace xla {
// buffers.
class HloOrdering {
public:
- HloOrdering() = default;
+ HloOrdering(const HloModule* module)
+ : module_(module), call_graph_(CallGraph::Build(module)) {}
virtual ~HloOrdering() = default;
// Returns true if instruction 'a' executes before instruction 'b'. This is
// not reflexive, that is, an instruction does not execute before itself.
- virtual bool ExecutesBefore(const HloInstruction* a,
- const HloInstruction* b) const = 0;
+ bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const;
// Returns the sequential instruction order for the given computation, or
// nullptr if the computation does not have a sequential ordering.
@@ -50,6 +51,21 @@ class HloOrdering {
const HloComputation& computation) const = 0;
virtual string ToString() const = 0;
+
+ protected:
+ // Returns true if instruction 'a' executes before instruction 'b'.
+ // Precondition: 'a' and 'b' are in the same computation.
+ //
+ // Derived classes should implement this method for determining order of
+ // instructions in the same comptuation. ExecutesBefore() analyzes the
+ // callgraph and uses this method to determine ordering of instructions in
+ // different computations.
+ virtual bool ExecutesBeforeInSameComputation(
+ const HloInstruction* a, const HloInstruction* b) const = 0;
+
+ const HloModule* module_;
+
+ std::unique_ptr<CallGraph> call_graph_;
};
// Base class for partial orderings implemented by a map of strict predecessors
@@ -58,11 +74,6 @@ class PredecessorHloOrdering : public HloOrdering {
public:
~PredecessorHloOrdering() override = default;
- // Returns true if instruction 'a' executes before instruction 'b'.
- // Instructions in different computations are not ordered.
- bool ExecutesBefore(const HloInstruction* a,
- const HloInstruction* b) const override;
-
// Returns nullptr indicating the computation does not have a sequential
// ordering.
const std::vector<const HloInstruction*>* SequentialOrder(
@@ -74,11 +85,12 @@ class PredecessorHloOrdering : public HloOrdering {
explicit PredecessorHloOrdering(const HloModule* module);
string ToStringHelper(const string& name) const;
- const HloModule* module_;
+ bool ExecutesBeforeInSameComputation(const HloInstruction* a,
+ const HloInstruction* b) const override;
- // For each each computation in the module, this is the set of the
- // instruction's strict predecessors. An instruction is not an element of its
- // own strict predecessor set.
+ // For each computation in the module, this is the set of the instruction's
+ // strict predecessors. An instruction is not an element of its own strict
+ // predecessor set.
//
// Subclasses should fill this in to define the desired ordering.
tensorflow::gtl::FlatMap<const HloComputation*,
@@ -150,12 +162,6 @@ class SequentialHloOrdering : public HloOrdering {
const HloModuleSequence& module_sequence);
~SequentialHloOrdering() override = default;
- // Instruction 'a' executes before 'b' if 'a' appears before 'b' in the
- // instruction sequence for the computation. Instructions in different
- // computations are unordered.
- bool ExecutesBefore(const HloInstruction* a,
- const HloInstruction* b) const override;
-
// Returns the sequential instruction order for the given computation.
const std::vector<const HloInstruction*>* SequentialOrder(
const HloComputation& computation) const override;
@@ -163,7 +169,9 @@ class SequentialHloOrdering : public HloOrdering {
string ToString() const override;
protected:
- const HloModule* module_;
+ bool ExecutesBeforeInSameComputation(const HloInstruction* a,
+ const HloInstruction* b) const override;
+
const HloModuleSequence module_sequence_;
// The position of every instruction in the HLO module in its respective
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 425bee601a..01b5fd9364 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -78,6 +78,83 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) {
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
}
+TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
+ // Tests the ordering of instructions in different computations using the
+ // following HLO code:
+ //
+ // Entry computation:
+ // %x = Call(A, {})
+ // %y = Call(B, {%x})
+ //
+ // Computation A:
+ // %a = Call(C, {})
+ //
+ // Computation B:
+ // %b = Call(C, {})
+ //
+ // Computation C:
+ // %c = Constant(42.0f)
+ //
+ // This results in a diamond-shaped callgraph.
+ HloModule module(TestName());
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder_c = HloComputation::Builder("C");
+ HloInstruction* c = builder_c.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloComputation* computation_c =
+ module.AddEmbeddedComputation(builder_c.Build());
+
+ auto builder_b = HloComputation::Builder("B");
+ builder_b.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* b = builder_b.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {}, computation_c));
+ HloComputation* computation_b =
+ module.AddEmbeddedComputation(builder_b.Build());
+
+ auto builder_a = HloComputation::Builder("A");
+ HloInstruction* a = builder_a.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {}, computation_c));
+ HloComputation* computation_a =
+ module.AddEmbeddedComputation(builder_a.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* x = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {}, computation_a));
+ HloInstruction* y = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {x}, computation_b));
+ module.AddEntryComputation(builder.Build());
+
+ DependencyHloOrdering ordering(&module);
+ EXPECT_TRUE(ordering.ExecutesBefore(x, y));
+ EXPECT_FALSE(ordering.ExecutesBefore(y, x));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(a, b));
+ EXPECT_FALSE(ordering.ExecutesBefore(b, a));
+
+ EXPECT_FALSE(ordering.ExecutesBefore(a, x));
+ EXPECT_TRUE(ordering.ExecutesBefore(a, y));
+ EXPECT_FALSE(ordering.ExecutesBefore(x, a));
+ EXPECT_FALSE(ordering.ExecutesBefore(y, a));
+
+ EXPECT_FALSE(ordering.ExecutesBefore(b, x));
+ EXPECT_FALSE(ordering.ExecutesBefore(b, y));
+ EXPECT_TRUE(ordering.ExecutesBefore(x, b));
+ EXPECT_FALSE(ordering.ExecutesBefore(y, b));
+
+ // Instruction 'c' is called from multiple callsites and should be unordered
+ // relative to all other instructions in the module.
+ EXPECT_FALSE(ordering.ExecutesBefore(c, a));
+ EXPECT_FALSE(ordering.ExecutesBefore(c, b));
+ EXPECT_FALSE(ordering.ExecutesBefore(c, x));
+ EXPECT_FALSE(ordering.ExecutesBefore(c, y));
+ EXPECT_FALSE(ordering.ExecutesBefore(a, c));
+ EXPECT_FALSE(ordering.ExecutesBefore(b, c));
+ EXPECT_FALSE(ordering.ExecutesBefore(x, c));
+ EXPECT_FALSE(ordering.ExecutesBefore(y, c));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 65c61627a1..44293f582e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -930,9 +930,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
const HloInstruction* instruction) const {
- TF_ASSIGN_OR_RETURN(const CallGraphNode* node,
- call_graph_->GetNode(instruction->parent()));
- const CallSite* callsite = node->GetCallSite(instruction);
+ const CallSite* callsite =
+ call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
return 0;
}
@@ -981,8 +980,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// instructions which are dead.
int64 net_instructions_added = 0;
- TF_ASSIGN_OR_RETURN(const CallGraphNode* call_graph_node,
- call_graph_->GetNode(computation));
+ const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
// Iterate through all instructions in the sequence. At each instruction
// (program point) if memory_usage exceeds the specified limit then
@@ -1080,7 +1078,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< memory_tracker.memory_usage();
}
- const CallSite* callsite = call_graph_node->GetCallSite(instruction);
+ const CallSite* callsite = call_graph_node.GetCallSite(instruction);
if (callsite != nullptr &&
callsite->context() == CallContext::kSequential &&
memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
@@ -1194,7 +1192,7 @@ StatusOr<bool> HloRematerialization::Run(
}));
// Compute peak memory usage of all computations in the module called in a
// sequential context.
- TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module));
+ call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
[this, sequence](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {