author Mark Heffernan <meheff@google.com> 2017-04-28 14:51:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-04-28 16:12:06 -0700
commit 765d97a4168429a730862c9898cc936b445a054c (patch)
tree 4f3eb5db989746d8faab6e7b9f474bf433cc3b33 /tensorflow/compiler/xla/service/hlo_rematerialization.cc
parent 998baa0f1fa8aee4382be1a00e4ae9ee70a6b194 (diff)
[XLA] Make HLO ordering module-scoped.
Add comparison of the ordering of HLO instructions that live in different computations, using the call graph. Previously, instructions in different computations were considered unordered. Ordering these instructions improves buffer liveness analysis and may enable better buffer sharing between values in different computations.
Change: 154592912
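The ordering comparison itself lands in the call-graph and ordering code rather than in this file, but the idea can be sketched briefly. The following is a toy model, not XLA source: Instruction, Computation, and ExecutesBefore are illustrative names, and real HLO computations can have multiple callsites and non-sequential call contexts, which this sketch ignores.

#include <string>

struct Computation;

struct Instruction {
  std::string name;
  int position;          // index in its computation's sequential order
  Computation* parent;   // computation containing this instruction
};

struct Computation {
  std::string name;
  Instruction* callsite = nullptr;  // calling instruction in the parent computation, if any
  int depth = 0;                    // depth in the call graph (entry computation = 0)
};

// Order two instructions that may live in different computations: lift
// whichever side is deeper in the call graph up through its computation's
// callsite until both land in the same computation, then compare positions.
bool ExecutesBefore(const Instruction* a, const Instruction* b) {
  while (a->parent != b->parent) {
    if (a->parent->depth >= b->parent->depth) {
      if (a->parent->callsite == nullptr) return false;  // no common context: unordered
      a = a->parent->callsite;
    } else {
      if (b->parent->callsite == nullptr) return false;
      b = b->parent->callsite;
    }
  }
  // If one side lifted onto the other, one instruction calls the computation
  // containing the other; this sketch conservatively reports "not before".
  if (a == b) return false;
  return a->position < b->position;
}

For example, if the entry computation calls a computation C from an instruction at position 3, an entry instruction at position 1 is ordered before every instruction in C, whereas the previous behavior would have treated the pair as unordered.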
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 12
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 65c61627a1..44293f582e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -930,9 +930,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
const HloInstruction* instruction) const {
- TF_ASSIGN_OR_RETURN(const CallGraphNode* node,
- call_graph_->GetNode(instruction->parent()));
- const CallSite* callsite = node->GetCallSite(instruction);
+ const CallSite* callsite =
+ call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
return 0;
}
@@ -981,8 +980,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// instructions which are dead.
int64 net_instructions_added = 0;
- TF_ASSIGN_OR_RETURN(const CallGraphNode* call_graph_node,
- call_graph_->GetNode(computation));
+ const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
// Iterate through all instructions in the sequence. At each instruction
// (program point) if memory_usage exceeds the specified limit then
@@ -1080,7 +1078,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< memory_tracker.memory_usage();
}
- const CallSite* callsite = call_graph_node->GetCallSite(instruction);
+ const CallSite* callsite = call_graph_node.GetCallSite(instruction);
if (callsite != nullptr &&
callsite->context() == CallContext::kSequential &&
memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
@@ -1194,7 +1192,7 @@ StatusOr<bool> HloRematerialization::Run(
}));
// Compute peak memory usage of all computations in the module called in a
// sequential context.
- TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module));
+ call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
[this, sequence](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
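The hunks above are mechanical: CallGraph::Build and node lookup no longer go through StatusOr, so the TF_ASSIGN_OR_RETURN plumbing disappears. A minimal before/after sketch, using only the signatures implied by this diff:

// Before: construction and lookup returned StatusOr, so each call site
// needed status unwrapping:
//   TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module));
//   TF_ASSIGN_OR_RETURN(const CallGraphNode* node,
//                       call_graph_->GetNode(computation));
//   const CallSite* callsite = node->GetCallSite(instruction);

// After: Build returns the graph directly and GetNode returns a reference:
call_graph_ = CallGraph::Build(module);
const CallGraphNode& node = call_graph_->GetNode(computation);
const CallSite* callsite = node.GetCallSite(instruction);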