about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/heap_simulator.cc
diff options
context:
space:
mode:
author	Dimitris Vardoulakis <dimvar@google.com>	2018-06-14 17:22:37 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-06-14 17:25:24 -0700
commit7e05b8a1c7fec4852e275e708555a759947270d7 (patch)
treed50789d00d38c0ff5cbf56d4780149760a55d3f0 /tensorflow/compiler/xla/service/heap_simulator.cc
parent9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (diff)
[TF:XLA] Account for subcomputations in heap simulator during scheduling.
PiperOrigin-RevId: 200646674
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r--	tensorflow/compiler/xla/service/heap_simulator.cc	52
1 file changed, 44 insertions(+), 8 deletions(-)
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 5dba50a63b..a04aa4069d 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -26,7 +26,8 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-StatusOr<int64> MinimumMemoryForModule(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const LogicalBuffer::SizeFunction& size_function) {
if (module_sequence.empty()) {
@@ -49,15 +50,19 @@ StatusOr<int64> MinimumMemoryForModule(
return result.heap_size;
}
-StatusOr<int64> MinimumMemoryForComputation(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation,
const std::vector<const HloInstruction*>& sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
+ sequence, points_to_analysis, size_function,
+ HeapSimulator::Options(), memory_by_computation));
return result.heap_size;
}
@@ -81,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr);
+ /*module_sequence=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -254,6 +261,12 @@ Status HeapSimulator::RunComputation(
Alloc(buffer, instruction);
}
}
+ // Account for the memory used by subcomputations when estimating the
+ // current heap size.
+ if (memory_by_computation_ != nullptr) {
+ algorithm_->AccountForSubcomputationMemory(instruction,
+ *memory_by_computation_);
+ }
// If the whole module is sequential, we can save memory by running the
// heap-simulation for sub-computations inline. E.g. the buffers for the
@@ -321,12 +334,15 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence)
+ const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation)
: no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence) {
+ module_sequence_(module_sequence),
+ memory_by_computation_(memory_by_computation) {
debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
}
@@ -495,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {
+ // We only count the memory usage of the largest subcomputation, instead of
+ // adding them all, because subcomputations won't execute in parallel.
+ int64 max_subcomputation_bytes = 0;
+ for (const auto* c : instruction->called_computations()) {
+ auto it = memory_by_computation.find(c);
+ if (it != memory_by_computation.end()) {
+ int64 subcomputation_bytes = it->second;
+ if (subcomputation_bytes > max_subcomputation_bytes) {
+ max_subcomputation_bytes = subcomputation_bytes;
+ }
+ }
+ }
+ max_heap_size_ =
+ std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
current_heap_size_ -= size;
}