diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 79 |
1 files changed, 38 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83cb113bfb..f7e82fb1f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase { } StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module, - HloSchedule* schedule) { + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - schedule, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &schedule)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { |