author    Dimitris Vardoulakis <dimvar@google.com>  2018-06-14 17:22:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-14 17:25:24 -0700
commit    7e05b8a1c7fec4852e275e708555a759947270d7 (patch)
tree      d50789d00d38c0ff5cbf56d4780149760a55d3f0
parent    9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (diff)
[TF:XLA] Account for subcomputations in heap simulator during scheduling.
PiperOrigin-RevId: 200646674
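The core of the change: the free functions MinimumMemoryForModule and MinimumMemoryForComputation become static members of HeapSimulator, and the latter (like HeapSimulator::Run) gains an optional memory_by_computation parameter. Passing nullptr preserves the old behavior; passing a map of per-computation minimum-memory estimates makes the simulator charge for kCall/kWhile subcomputations. A minimal sketch of the new call pattern, written inside namespace xla against the declarations changed below; the wrapper function and its name are illustrative, and the inputs are assumed to come from the surrounding scheduling code:

#include "tensorflow/compiler/xla/service/heap_simulator.h"

// Hypothetical helper showing the new optional argument.
StatusOr<int64> EstimateWithSubcomputations(
    const HloComputation& computation,
    const std::vector<const HloInstruction*>& sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
        memory_by_computation) {
  // Same call as before, plus the map of already-scheduled subcomputations;
  // omitting the last argument keeps the old, subcomputation-unaware estimate.
  return HeapSimulator::MinimumMemoryForComputation(
      computation, sequence, points_to_analysis, size_function,
      &memory_by_computation);
}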
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                  |   1
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc   |   5
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.cc      |  52
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.h       |  58
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc |   3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling.cc      |  37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling_test.cc | 104
7 files changed, 204 insertions(+), 56 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index cb2e159a38..396ce13e7f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1101,6 +1101,7 @@ tf_cc_test(
srcs = ["hlo_scheduling_test.cc"],
deps = [
":buffer_value",
+ ":heap_simulator",
":hlo",
":hlo_ordering",
":hlo_scheduling",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 5d3b0cb333..afe4b2e142 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -631,8 +631,9 @@ Status BufferAssignment::ComputeSummaryStats() {
}
}
if (module_sequence.size() == module_->computation_count()) {
- TF_ASSIGN_OR_RETURN(const int64 min_size,
- MinimumMemoryForModule(module_sequence, buffer_size_));
+ TF_ASSIGN_OR_RETURN(
+ const int64 min_size,
+ HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 5dba50a63b..a04aa4069d 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -26,7 +26,8 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-StatusOr<int64> MinimumMemoryForModule(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const LogicalBuffer::SizeFunction& size_function) {
if (module_sequence.empty()) {
@@ -49,15 +50,19 @@ StatusOr<int64> MinimumMemoryForModule(
return result.heap_size;
}
-StatusOr<int64> MinimumMemoryForComputation(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation,
const std::vector<const HloInstruction*>& sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
+ sequence, points_to_analysis, size_function,
+ HeapSimulator::Options(), memory_by_computation));
return result.heap_size;
}
@@ -81,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr);
+ /*module_sequence=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -254,6 +261,12 @@ Status HeapSimulator::RunComputation(
Alloc(buffer, instruction);
}
}
+ // Account for the memory used by subcomputations when estimating the
+ // current heap size.
+ if (memory_by_computation_ != nullptr) {
+ algorithm_->AccountForSubcomputationMemory(instruction,
+ *memory_by_computation_);
+ }
// If the whole module is sequential, we can save memory by running the
// heap-simulation for sub-computations inline. E.g. the buffers for the
@@ -321,12 +334,15 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence)
+ const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation)
: no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence) {
+ module_sequence_(module_sequence),
+ memory_by_computation_(memory_by_computation) {
debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
}
@@ -495,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {
+ // We only count the memory usage of the largest subcomputation, instead of
+ // adding them all, because subcomputations won't execute in parallel.
+ int64 max_subcomputation_bytes = 0;
+ for (const auto* c : instruction->called_computations()) {
+ auto it = memory_by_computation.find(c);
+ if (it != memory_by_computation.end()) {
+ int64 subcomputation_bytes = it->second;
+ if (subcomputation_bytes > max_subcomputation_bytes) {
+ max_subcomputation_bytes = subcomputation_bytes;
+ }
+ }
+ }
+ max_heap_size_ =
+ std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
current_heap_size_ -= size;
}
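Why the new AccountForSubcomputationMemory takes the maximum rather than the sum: an instruction's called computations (e.g. a while's condition and body) never run concurrently, so at any instant only the largest of them can add to the live set. A worked trace of the hook, using the 17-byte condition and 16-byte body from the new hlo_scheduling test further down (values illustrative):

// At the kWhile instruction:
//   current_heap_size_       = 16    (the f32[4] while value)
//   memory_by_computation    = {WhileCond: 17, WhileBody: 16}
//   max_subcomputation_bytes = max(17, 16) = 17
//   max_heap_size_           = max(max_heap_size_, 16 + 17) = 33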
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 3be3bb8e7f..811a6042df 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -34,21 +34,6 @@ limitations under the License.
namespace xla {
-// Returns the minimum memory required to compute an HLO module where all
-// computations have been scheduled (represented by the given module_sequence),
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function);
-
-// Returns the minimum memory required to compute the given computation,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function);
-
// Forward declare classes defined below.
class HeapAlgorithm;
@@ -100,6 +85,23 @@ class HeapSimulator {
const BufferValueFlatSet* buffers_to_assign;
};
+ // Returns the minimum memory required to compute an HLO module where all
+ // computations have been scheduled (represented by the given
+ // module_sequence), assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForModule(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function);
+
+ // Returns the minimum memory required to compute the given computation,
+ // assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
+
// Run the heap simulation with the given algorithm, assuming the given
// module_sequence, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
@@ -126,7 +128,9 @@ class HeapSimulator {
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ const Options& options = Options(),
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
private:
// If 'module_sequence' is non-null, it is used to find kCall and kWhile
@@ -135,7 +139,9 @@ class HeapSimulator {
HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence);
+ const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
Status RunComputation(
@@ -159,7 +165,13 @@ class HeapSimulator {
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
+ // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // set by hlo scheduling. Then, in RunComputation, we check both in order to
+ // handle subcomputations. It would be good to unify the handling of
+ // subcomputations, but it's not clear how.
const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation_;
// In addition to Alloc and Free, the heap simulator exposes a concept of
// buffer sharing. When ShareBuffer is called, instead of allocating new
@@ -204,6 +216,11 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
+ virtual void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {}
+
// Free de-allocates a previously allocated buffer.
virtual void Free(const BufferValue* buffer, int64 size) = 0;
@@ -222,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
~NoFragmentationStatsHeap() override = default;
void Alloc(const BufferValue* buffer, int64 size) override;
+
+ void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) override;
+
void Free(const BufferValue* buffer, int64 size) override;
+
Result Finish() override;
private:
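Giving AccountForSubcomputationMemory an empty default body keeps every existing HeapAlgorithm source-compatible; only algorithms that track a peak need to override it, as NoFragmentationStatsHeap does above. A minimal sketch of another algorithm opting in, written against the declarations in this header (PeakTrackingHeap is hypothetical):

class PeakTrackingHeap : public HeapAlgorithm {
 public:
  void Alloc(const BufferValue* buffer, int64 size) override {
    current_ += size;
    peak_ = std::max(peak_, current_);
  }

  // Charge for the largest subcomputation that could be live here,
  // mirroring NoFragmentationStatsHeap.
  void AccountForSubcomputationMemory(
      const HloInstruction* instruction,
      const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
          memory_by_computation) override {
    int64 largest = 0;
    for (const auto* c : instruction->called_computations()) {
      auto it = memory_by_computation.find(c);
      if (it != memory_by_computation.end()) {
        largest = std::max(largest, it->second);
      }
    }
    peak_ = std::max(peak_, current_ + largest);
  }

  void Free(const BufferValue* buffer, int64 size) override {
    current_ -= size;
  }

  Result Finish() override {
    Result result;
    result.heap_size = peak_;
    return result;
  }

 private:
  int64 current_ = 0;
  int64 peak_ = 0;
};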
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 309ab85f78..93d7a14125 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -89,7 +89,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
cond_lt};
module_sequence[body_computation] = {body_param};
module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie());
+ EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
+ .ValueOrDie());
}
const char kAlloc[] = "Alloc";
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index b14ade3549..641b9ecec9 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -375,7 +375,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationsInModule(
+StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -498,29 +498,29 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
std::vector<const HloInstruction*> list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 list_memory,
- MinimumMemoryForComputation(computation, list_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 list_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, list_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 dfs_memory,
- MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
- size_function));
+ TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, dfs_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 post_order_memory,
- MinimumMemoryForComputation(computation, post_order_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, post_order_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
@@ -551,12 +551,13 @@ StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
- ScheduleComputationsInModule(
+ ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
- MinimumMemoryForComputation(*computation, one_computation_sequence,
- *points_to_analysis, size_function)
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, one_computation_sequence, *points_to_analysis,
+ size_function, &memory_by_computation)
.ValueOrDie();
sequence[computation] = std::move(one_computation_sequence);
}
@@ -571,8 +572,8 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
- return ScheduleComputationsInModule(computation, *points_to_analysis,
- size_function, nullptr, empty_map);
+ return ScheduleComputationHelper(computation, *points_to_analysis,
+ size_function, nullptr, empty_map);
}
} // namespace xla
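Two things to note in this file: the recursive overload of ScheduleComputationsInModule is renamed ScheduleComputationHelper, leaving a single public entry point under the old name, and every MinimumMemoryForComputation call now threads &memory_by_computation through. Because ScheduleComputationsInModule walks module.MakeComputationPostOrder(), callees are scheduled, and their minimum-memory estimates recorded, before their callers, so the map is always complete for the computation being scheduled; DefaultMemoryScheduler therefore compares its list, DFS, and post-order candidates on subcomputation-aware numbers.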
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 6f1b1215d3..73f22f81f4 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -144,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// ROOT %subtract = f32[4]{0} subtract(
// f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
// }
- // %SubcomputationsNotAccounted () -> f32[2,4] {
+ // %ListAccountsForSubcomputations () -> f32[2,4] {
// %constant.3 = f32[2,4]{1,0} constant(
// f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
// %transpose = f32[2,4]{1,0} transpose(
@@ -210,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- },
- ListMemoryScheduler));
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
SequentialHloOrdering ordering(module.get(), sequence);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
@@ -228,6 +229,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // Without memory_by_computation, the heap simulator ignores subcomputations.
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // With memory_by_computation, the heap simulator accounts for
+ // subcomputations. The maximum doesn't change here because the while body
+ // isn't live during the peak.
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
}
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
@@ -325,5 +344,70 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}
+TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
+ auto module = CreateNewModule();
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+ // param != 0
+ // Needs 17 bytes
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+ HloInstruction* zero_vector = cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+ auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+ // param - 1
+ // Needs 16 bytes
+ auto body_builder = HloComputation::Builder("WhileBody");
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "body_param"));
+ HloInstruction* one_vector = body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, body_param, one_vector));
+ auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* while_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ // Creates 16 bytes, ignoring subcomputations
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ r1f32, cond_computation, body_computation, while_init));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ // Verify that all instructions are in the sequence.
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // Without memory_by_computation, the heap simulator ignores subcomputations.
+ EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // With memory_by_computation, subcomputations are accounted for: 16 live
+ // bytes plus max(17, 16) = 17 bytes for the largest subcomputation = 33.
+ EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
+}
+
} // namespace
} // namespace xla
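Taken together, the two updated tests pin down the semantics: in ListAccountsForSubcomputations the estimate stays at 80 bytes with or without the map, because the while body isn't live during the peak, while in HeapSimulatorAccountsForSubcomputations the estimate grows from 16 to 33 bytes, the 16 live bytes plus the 17-byte condition, the larger of the two subcomputations.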