about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 79
1 files changed, 38 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83cb113bfb..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
- HloModule* module,
- HloSchedule* schedule) {
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- schedule, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The rematerialized broadcast should be immediately before the concat in the
// sequence.
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module));
// No instructions should have been rematerialized.
EXPECT_FALSE(changed);
@@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &schedule));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
// Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {