aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_scheduling_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc248
1 files changed, 248 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 639c20ad8e..930801288a 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -415,5 +417,251 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
.ValueOrDie());
}
+TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) {
+ // Updating the schedule of an unchanged HLO module should not affect the
+ // schedule at all.
+ const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+ std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second;
+
+ EXPECT_EQ(entry_schedule.size(), 6);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(entry_schedule, sequence.begin()->second);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) {
+ // Add some additional instructions to a module and verify the schedule can be
+ // updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ HloComputation* entry = module->entry_computation();
+ const Shape shape = entry->root_instruction()->shape();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+ entry->set_root_instruction(sub);
+
+ auto in_schedule = [&](const HloInstruction* hlo) {
+ return std::find(sequence.at(entry).begin(), sequence.at(entry).end(),
+ hlo) != sequence.at(entry).end();
+ };
+
+ EXPECT_EQ(sequence.at(entry).size(), 6);
+ EXPECT_FALSE(in_schedule(constant));
+ EXPECT_FALSE(in_schedule(sub));
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 8);
+ EXPECT_TRUE(in_schedule(constant));
+ EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+ // Add and delete some instructions from a module and verify that the schedule
+ // can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ // Set the entry root to some expression containing just a parameter and a
+ // constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* new_root = entry->AddInstruction(
+ HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+ constant, entry->parameter_instruction(0)));
+ entry->set_root_instruction(new_root);
+
+ // DCE should remove everything but the parameters and the newly added code.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(entry).size(), 6);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 4);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) {
+ // Completely replace a module with an entirely new set of instructions and
+ // verify that the schedule can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+ a = f32[] constant(42.0)
+ b = f32[] constant(123.0)
+ ROOT sum = f32[] add(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ // Replace the entry computation with the negation of a constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ entry->set_root_instruction(new_root);
+
+ // DCE the old instructions.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(entry).size(), 3);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(entry).size(), 2);
+}
+
+TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) {
+ // Create changes to more than one computation in an HLO module and verify
+ // that the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+ tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
+ id_sequence = ComputeIdSchedule(sequence);
+
+ const HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->operand(0);
+ HloComputation* body = xla_while->while_body();
+ HloComputation* cond = xla_while->while_condition();
+
+ // Negate the root of the cond.
+ cond->set_root_instruction(cond->AddInstruction(
+ HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kNot, cond->root_instruction())));
+
+ // Replace the body with a computation which just passes through its
+ // parameter.
+ body->set_root_instruction(body->parameter_instruction(0));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(sequence.at(body).size(), 7);
+ EXPECT_EQ(sequence.at(cond).size(), 4);
+
+ TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
+ TF_ASSERT_OK(VerifySchedule(*module, sequence));
+
+ EXPECT_EQ(sequence.at(body).size(), 1);
+ EXPECT_EQ(sequence.at(cond).size(), 5);
+}
+
} // namespace
} // namespace xla