aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/heap_simulator_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc43
1 files changed, 21 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 576c5ff7a4..00a25db467 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
- .ValueOrDie());
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(cond_computation,
+ {cond_param, cond_iter, cond_data, cond_lt});
+ schedule.set_sequence(body_computation, {body_param});
+ schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(
+ 56,
+ HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
const char kAlloc[] = "Alloc";
@@ -149,10 +153,11 @@ class HeapSimulatorTracker {
auto zero_size = [](const BufferValue& buffer) { return 0; };
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(
- std::move(algorithm), *module_->entry_computation(),
- instruction_sequence, *points_to_analysis_, zero_size)
- .ConsumeValueOrDie();
+ result_ =
+ HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
+ HloInstructionSequence(instruction_sequence),
+ *points_to_analysis_, zero_size)
+ .ConsumeValueOrDie();
}
explicit HeapSimulatorTracker(const string& name) {
@@ -168,11 +173,12 @@ class HeapSimulatorTracker {
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
// Construct the module sequence grouped by computation.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_.get());
tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
- module_sequence[instruction->parent()].push_back(instruction);
+ schedule.GetOrCreateSequence(instruction->parent())
+ .push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i;
}
@@ -185,8 +191,8 @@ class HeapSimulatorTracker {
};
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(std::move(algorithm), *module_,
- module_sequence, *points_to_analysis_, size_fn)
+ result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
+ *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
}
@@ -353,13 +359,6 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) {
(neg_buffer == output_buffer_1));
}
-PrecisionConfigProto DefaultPrecisionConfig(int operands) {
- PrecisionConfigProto precision_config;
- precision_config.mutable_operand_precision()->Resize(
- operands, PrecisionConfigProto::DEFAULT);
- return precision_config;
-}
-
TEST_F(HeapSimulatorTest, MultiplyDot) {
auto builder = HloComputation::Builder(TestName());
auto paramA = builder.AddInstruction(