diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator_test.cc | 43 |
1 files changed, 21 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 576c5ff7a4..00a25db467 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -149,10 +153,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<HeapCallRecorder>(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +191,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<HeapCallRecorder>(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } @@ -353,13 +359,6 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) { (neg_buffer == output_buffer_1)); } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( |