aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 17:17:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 17:22:22 -0700
commit6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch)
tree1afd3dff710c4f63bae267807435abdcec784edb
parent017599d0a1fa7a7227a43649db67e96311033a4e (diff)
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description *** Add HloSchedule class representing a sequential order of an HloModule. Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encap... *** PiperOrigin-RevId: 211726890
-rw-r--r--tensorflow/compiler/xla/service/BUILD48
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc28
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc98
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc56
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h4
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc43
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h48
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc36
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc86
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h22
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc101
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h19
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.cc291
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.h151
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule_test.cc341
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc230
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h54
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc343
27 files changed, 1325 insertions, 905 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 64141ed191..ab86dce510 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -989,6 +989,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1036,6 +1037,7 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1049,6 +1051,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1062,6 +1065,7 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
+ ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1082,6 +1086,7 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1089,6 +1094,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -1102,6 +1108,7 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
+ ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1125,6 +1132,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1170,6 +1178,43 @@ cc_library(
)
cc_library(
+ name = "hlo_schedule",
+ srcs = ["hlo_schedule.cc"],
+ hdrs = ["hlo_schedule.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_schedule_test",
+ srcs = ["hlo_schedule_test.cc"],
+ deps = [
+ ":heap_simulator",
+ ":hlo",
+ ":hlo_dce",
+ ":hlo_ordering",
+ ":hlo_parser",
+ ":hlo_schedule",
+ ":hlo_scheduling",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
+
+cc_library(
name = "hlo_scheduling",
srcs = ["hlo_scheduling.cc"],
hdrs = ["hlo_scheduling.h"],
@@ -1177,6 +1222,7 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1205,6 +1251,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2366,6 +2413,7 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 8b8c6bfd26..0f0af57626 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -617,18 +617,24 @@ Status BufferAssignment::ComputeSummaryStats() {
}
// Only compute total fragmentation if all computations have schedules.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_);
+ bool schedule_complete = true;
for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- liveness_->hlo_ordering().SequentialOrder(*computation);
- if (sequence != nullptr) {
- module_sequence.emplace(computation, *sequence);
+ if (!computation->IsFusionComputation()) {
+ const std::vector<const HloInstruction*>* sequence =
+ liveness_->hlo_ordering().SequentialOrder(*computation);
+ if (sequence == nullptr) {
+ schedule_complete = false;
+ } else {
+ schedule.set_sequence(computation, *sequence);
+ }
}
}
- if (module_sequence.size() == module_->computation_count()) {
+ if (schedule_complete) {
+ TF_RETURN_IF_ERROR(schedule.Verify());
TF_ASSIGN_OR_RETURN(
const int64 min_size,
- HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
+ HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
@@ -1064,7 +1070,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// since buffers for kCall, kWhile, and kConditional sub-computations are
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(&assignment->module());
FlatSet<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
@@ -1072,7 +1078,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
- module_sequence[computation] = *instruction_sequence;
+ schedule.set_sequence(computation, *instruction_sequence);
all_buffers_to_assign.insert(buffers_to_assign.begin(),
buffers_to_assign.end());
}
@@ -1090,7 +1096,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const HeapSimulator::Result result,
HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<LazyBestFitHeap>(alignment)),
- assignment->module(), module_sequence,
+ assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
@@ -1121,7 +1127,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
HeapSimulator::Run(
absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<LazyBestFitHeap>(alignment)),
- *computation, *instruction_sequence,
+ *computation, HloInstructionSequence(*instruction_sequence),
assignment->points_to_analysis(), assignment->buffer_size_,
options));
AssignBuffersFromHeapSimulator(result, assignment,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 56bd67fb55..5a231c173d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase {
HloModule* module,
absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[module->entry_computation()] =
- std::vector<const HloInstruction*>(instruction_sequence.begin(),
- instruction_sequence.end());
+ HloSchedule schedule(module);
+ schedule.set_sequence(module->entry_computation(), instruction_sequence);
return BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module,
- module_sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -1785,11 +1783,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module, sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -2096,17 +2093,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// Create a sequential order among all the instructions in the entry
// computation, since the issue this test stresses depends on the order the
// nodes are traversed during BufferAssignment.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[module->entry_computation()] = {
- token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+ schedule.set_sequence(
+ module->entry_computation(),
+ {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
+ TF_ASSERT_OK(schedule.Verify());
+
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
- BufferAssigner::Run(
- module, absl::make_unique<SequentialHloOrdering>(module, sequence),
- backend().compiler()->BufferSizeBytesFunction(),
- [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// The result tuple elements must be assigned with different buffers.
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
@@ -2263,29 +2268,6 @@ ENTRY Main {
GetAllocation(*buffers, param0, {1, 1}));
}
-static bool IsPostOrderTraversal(
- const std::vector<const HloInstruction*>& sequence) {
- tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
- auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
- return seen_so_far.count(instruction) == 0;
- };
-
- for (auto instruction : sequence) {
- if (std::any_of(instruction->operands().begin(),
- instruction->operands().end(), has_not_been_seen_yet) ||
- std::any_of(instruction->control_predecessors().begin(),
- instruction->control_predecessors().end(),
- has_not_been_seen_yet)) {
- return false; // Not a post order.
- }
- if (!seen_so_far.insert(instruction).second) {
- return false; // Not a "traversal".
- }
- }
-
- return true;
-}
-
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
@@ -2340,27 +2322,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
RunCopyInsertion(module);
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
- // To trigger b/38494731, we want a specific Hlo sequence for the
+ // To trigger b/38494731, we want a specific Hlo schedule for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
- sequence[module->entry_computation()] = {
- input1, weights1, one, output1, while1->operand(0), while1,
- input0, weights0, zero, output0, while0->operand(0), while0,
- gte0, gte1, root_add};
+ schedule.set_sequence(module->entry_computation(),
+ {input1, weights1, one, output1, while1->operand(0),
+ while1, input0, weights0, zero, output0,
+ while0->operand(0), while0, gte0, gte1, root_add});
- // If this ASSERT_TRUE fails, we constructed a bogus sequence above
- // and this test itself is buggy.
- ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+ // If this ASSERT fails, we constructed a bogus sequence above and this test
+ // itself is buggy.
+ TF_ASSERT_OK(schedule.Verify());
auto assignment =
- BufferAssigner::Run(
- module, absl::make_unique<SequentialHloOrdering>(module, sequence),
- ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true)
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true)
.ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 26e26e316d..414bfe7999 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
auto module = CreateNewModule();
HloComputation* entry = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- module_sequence.emplace(computation, order);
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation, {param, negate, exp, add});
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build(add));
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, add, recv,
- recv_done, send, send_done};
- module_sequence.emplace(computation, order);
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation,
+ {param, add, token, recv, recv_done, send, send_done});
+ TF_ASSERT_OK(schedule.Verify());
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Check the root instruction (add) buffer interferes with the recv buffer.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 796f36510e..e7b6075994 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -584,16 +584,14 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
- DFSMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
// Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence),
+ absl::make_unique<SequentialHloOrdering>(schedule),
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
@@ -627,9 +625,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation, embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
string function_name_prefix = entry_computation->name().empty()
@@ -637,9 +636,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
: entry_computation->name();
TF_ASSIGN_OR_RETURN(
llvm::Function * entry_function,
- ir_emitter.EmitComputation(entry_computation, function_name_prefix,
- /*is_top_level_computation=*/true,
- &module_sequence.at(entry_computation)));
+ ir_emitter.EmitComputation(
+ entry_computation, function_name_prefix,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(entry_computation).instructions()));
string function_name = [&]() {
llvm::SmallVector<char, 40> function_name_vector;
@@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
- TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction()));
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module, module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation,
- embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
const string& entry_point_name = options.entry_point_name();
- TF_ASSIGN_OR_RETURN(
- llvm::Function * entry_function,
- ir_emitter.EmitComputation(computation, entry_point_name,
- /*is_top_level_computation=*/true,
- &module_sequence.at(computation)));
+ TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
+ ir_emitter.EmitComputation(
+ computation, entry_point_name,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(computation).instructions()));
CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e5cf15c686..df8c2a636b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -110,7 +110,7 @@ IrEmitter::IrEmitter(
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order) {
+ const std::vector<const HloInstruction*>* instruction_order) {
string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
<< "]; ordered? " << (instruction_order != nullptr);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 58a333b8fb..3df99464ba 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Function*> EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order);
+ const std::vector<const HloInstruction*>* instruction_order);
llvm::IRBuilder<>* b() { return &b_; }
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a68b7a1bef..13ccff35f8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -813,6 +813,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
+ "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index 743035a84e..ea9376e101 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
@@ -198,11 +199,12 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
// All kernels are launched on a single stream, so there's no loss of
// concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN(
- schedule->thunk_launch_order_,
- ScheduleOneComputation(
+ HloInstructionSequence sequence,
+ ScheduleComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
}));
+ schedule->thunk_launch_order_ = sequence.instructions();
} else {
// BFS tends to increase concurrency, but also increases memory usage.
BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
index 30a0e7cecd..07a7fc67aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
@@ -33,7 +33,9 @@ namespace gpu {
// launches, because thunks may be scheduled onto concurrent streams. This
// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
// minimize allocations), and also by ThunkSchedule to determine the thunk
-// launch order.
+// launch order. This class differs from xla::HloSchedule in that HloSchedule
+// represents a total order of all instructions in the module for backends which
+// execute HLO instructions strictly sequentially.
class GpuHloSchedule {
public:
// Constructs an GpuHloSchedule for the given module, based on the given
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 38c3982ebf..e0f3a7e0e2 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
+ if (schedule.empty()) {
return 0;
}
- const HloModule* module = module_sequence.begin()->first->parent();
+ const HloModule* module = schedule.module();
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -47,14 +47,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
+ schedule, *points_to_analysis, size_function));
return result.heap_size;
}
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
@@ -71,13 +70,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options) {
- HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
+ HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
const HloComputation* entry_computation = module.entry_computation();
- const std::vector<const HloInstruction*>& instruction_sequence =
- FindOrDie(module_sequence, entry_computation);
+ const HloInstructionSequence& instruction_sequence =
+ schedule.sequence(entry_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(
*entry_computation, instruction_sequence, points_to_analysis));
return heap.Finish();
@@ -86,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr, memory_by_computation);
+ /*schedule=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -102,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis) {
VLOG(3) << "Computation:\n" << computation.ToString();
// The goal here is to minimize memory usage, assuming the given sequential
@@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation(
// set of instructions that need to be visited contains all users of all
// aliases, that is, all users of all instructions that have the buffer
// contained in their points-to set.
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const PointsToSet& points_to =
points_to_analysis.GetPointsToSet(instruction);
const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
@@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation(
std::vector<const BufferValue*> dead_buffers_to_free;
std::vector<const BufferValue*> operand_buffers_to_free;
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const TuplePointsToAnalysis::BufferDefinitionVector&
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
@@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation(
// The order that the sub-computations are simulated does not affect
// correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
- if (module_sequence_ != nullptr) {
+ if (schedule_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kConditional ||
instruction->opcode() == HloOpcode::kWhile) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
- const std::vector<const HloInstruction*>& called_sequence =
- FindOrDie(*module_sequence_, called_computation);
+ const HloInstructionSequence& called_sequence =
+ schedule_->sequence(called_computation);
TF_RETURN_IF_ERROR(RunComputation(
*called_computation, called_sequence, points_to_analysis));
}
@@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const HloSchedule* schedule,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence),
+ schedule_(schedule),
memory_by_computation_(memory_by_computation) {
- debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
+ debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}
HeapSimulator::~HeapSimulator() {}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index af05bedee7..ffbf947d5a 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -88,23 +89,22 @@ class HeapSimulator {
// Returns the minimum memory required to compute an HLO module where all
// computations have been scheduled (represented by the given
- // module_sequence), assuming no fragmentation.
+ // schedule), assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function);
// Returns the minimum memory required to compute the given computation,
// assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation = nullptr);
// Run the heap simulation with the given algorithm, assuming the given
- // module_sequence, which must contain a topologically-consistent total
+ // schedule, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
// if instructions are not run in exactly this sequence.
//
@@ -112,12 +112,12 @@ class HeapSimulator {
// to running on a per-computation basis, since we can re-use buffer space for
// called sub-computations.
//
- static StatusOr<Result> Run(
- std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
+ const HloModule& module,
+ const HloSchedule& schedule,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options = Options());
// Same as above, but runs on a single computation. The 'instruction_sequence'
// must contain a topologically-consistent total ordering of all instructions
@@ -126,7 +126,7 @@ class HeapSimulator {
static StatusOr<Result> Run(
std::unique_ptr<HeapAlgorithm> algorithm,
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
const Options& options = Options(),
@@ -134,21 +134,19 @@ class HeapSimulator {
memory_by_computation = nullptr);
private:
- // If 'module_sequence' is non-null, it is used to find kCall and kWhile
+ // If 'schedule' is non-null, it is used to find kCall and kWhile
// sub-computations, and the heap simulation for those sub-computations will
// be run recursively. I.e. the simulation is run over the whole module.
- HeapSimulator(
- std::unique_ptr<HeapAlgorithm> algorithm,
- const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
- memory_by_computation = nullptr);
+ HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options, const HloSchedule* schedule = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
- Status RunComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
- const TuplePointsToAnalysis& points_to_analysis);
+ Status RunComputation(const HloComputation& computation,
+ const HloInstructionSequence& instruction_sequence,
+ const TuplePointsToAnalysis& points_to_analysis);
bool IgnoreBuffer(const BufferValue* buffer) const;
void Alloc(const BufferValue* buffer, const HloInstruction* instruction);
@@ -169,11 +167,11 @@ class HeapSimulator {
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
- // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // schedule_ is set by buffer assignment, and memory_by_computation_ is
// set by hlo scheduling. Then, in RunComputation, we check both in order to
// handle subcomputations. It would be good to unify the handling of
// subcomputations, but it's not clear how.
- const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const HloSchedule* schedule_;
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation_;
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 7ad8a107e1..00a25db467 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
- .ValueOrDie());
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(cond_computation,
+ {cond_param, cond_iter, cond_data, cond_lt});
+ schedule.set_sequence(body_computation, {body_param});
+ schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(
+ 56,
+ HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
const char kAlloc[] = "Alloc";
@@ -149,10 +153,11 @@ class HeapSimulatorTracker {
auto zero_size = [](const BufferValue& buffer) { return 0; };
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(
- std::move(algorithm), *module_->entry_computation(),
- instruction_sequence, *points_to_analysis_, zero_size)
- .ConsumeValueOrDie();
+ result_ =
+ HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
+ HloInstructionSequence(instruction_sequence),
+ *points_to_analysis_, zero_size)
+ .ConsumeValueOrDie();
}
explicit HeapSimulatorTracker(const string& name) {
@@ -168,11 +173,12 @@ class HeapSimulatorTracker {
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
// Construct the module sequence grouped by computation.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_.get());
tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
- module_sequence[instruction->parent()].push_back(instruction);
+ schedule.GetOrCreateSequence(instruction->parent())
+ .push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i;
}
@@ -185,8 +191,8 @@ class HeapSimulatorTracker {
};
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(std::move(algorithm), *module_,
- module_sequence, *points_to_analysis_, size_fn)
+ result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
+ *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
}
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 54abe3345d..0cd0ab36fc 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
// For a sequential order, there is interference iff the negate is after
// the while.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[body] = {body_param, body_root};
- sequence[condition] = {cond_param, cond_root};
+ HloSchedule schedule(module_);
+ schedule.set_sequence(body, {body_param, body_root});
+ schedule.set_sequence(condition, {cond_param, cond_root});
{
- sequence[entry] = {init, xla_while, negate, entry_root};
- SequentialHloOrdering ordering(module_, sequence);
+ schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
{
- sequence[entry] = {init, negate, xla_while, entry_root};
- SequentialHloOrdering ordering(module_, sequence);
+ schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 72b236801a..510d6360a1 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
bool ssa_form = GetParam();
RunAnalysis(ssa_form);
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param, xla_while}});
- sequence.insert({condition, {cond_param, cond_constant}});
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, xla_while});
+ schedule.set_sequence(condition, {cond_param, cond_constant});
// Construct the order such that 'constant' and its use 'exp' are before
// body_param.
- sequence.insert({body, {constant, exp, body_param, add}});
+ schedule.set_sequence(
+ body, {constant, exp, body_param, add, dead_constant, dead_negate});
+ TF_ASSERT_OK(schedule.Verify());
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
// 'add' is live out of the body and will interfere with any later instructions
// such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- sequence.emplace(entry, order);
-
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, negate, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 0581d5c404..2105f7a349 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore(
VLOG(4) << a << " not defined before " << b;
return false;
}
+
+ if (a.live_out_of_module()) {
+ VLOG(4) << a << " is live out of module and defined before " << b;
+ return false;
+ }
+
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
@@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore(
return false;
}
}
+
+ if (a.instruction()->parent() == b.instruction()->parent()) {
+ for (const HloPosition& position : a.positions()) {
+ if (position.instruction ==
+ a.instruction()->parent()->root_instruction()) {
+ VLOG(4) << a << " is live out of computation and defined before " << b
+ << " which is in same computation";
+ return false;
+ }
+ }
+ }
+
return true;
}
@@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const {
return ToStringHelper("DependencyHloOrdering");
}
-SequentialHloOrdering::SequentialHloOrdering(
- const HloModule* module, const HloModuleSequence& module_sequence)
- : HloOrdering(module), module_sequence_(module_sequence) {
+SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
+ : HloOrdering(schedule.module()), schedule_(schedule) {
+ Initialize();
+}
+
+SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
+ : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
+ Initialize();
+}
+
+void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position.
- for (auto computation_order : module_sequence_) {
- const std::vector<const HloInstruction*>& order = computation_order.second;
+ TF_DCHECK_OK(schedule_.Verify());
+ for (const auto& computation_sequence : schedule_.sequences()) {
+ const std::vector<const HloInstruction*>& order =
+ computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) {
- DCHECK_EQ(0, order_position_.count(order[i]));
- order_position_.emplace(order[i], i);
+ InsertOrDie(&order_position_, order[i], i);
}
}
}
@@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
const std::vector<const HloInstruction*>*
SequentialHloOrdering::SequentialOrder(
const HloComputation& computation) const {
- auto find_it = module_sequence_.find(&computation);
- return find_it == module_sequence_.end() ? nullptr : &find_it->second;
+ return schedule_.is_computation_scheduled(&computation)
+ ? &schedule_.sequence(&computation).instructions()
+ : nullptr;
}
string SequentialHloOrdering::ToString() const {
- std::vector<string> pieces;
- pieces.push_back("SequentialHloOrdering");
- for (auto* computation : module_->computations()) {
- pieces.push_back(
- absl::StrFormat("computation %s order:", computation->name()));
- // Gather all instructions in the module sequence for this computation and
- // sort them by their position.
- std::vector<const HloInstruction*> instructions;
- for (auto& instruction_position : order_position_) {
- const HloInstruction* instruction = instruction_position.first;
- if (instruction->parent() == computation) {
- instructions.push_back(instruction);
- }
- }
- std::sort(instructions.begin(), instructions.end(),
- [this](const HloInstruction* a, const HloInstruction* b) {
- return order_position_.at(a) < order_position_.at(b);
- });
- for (auto instruction : instructions) {
- pieces.push_back(absl::StrFormat(" %s", instruction->name()));
- }
- }
- return absl::StrJoin(pieces, "\n");
-}
-
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence) {
- for (auto computation_pair : module_sequence) {
- const HloComputation* computation = computation_pair.first;
- const std::vector<const HloInstruction*>& computation_sequence =
- computation_pair.second;
- out << "Computation " << computation->name() << ":\n";
- for (auto* instruction : computation_sequence) {
- out << " " << instruction->name() << "\n";
- }
- }
- return out;
+ return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index 985f3fa64d..b21071c4b2 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
// interference is reduced relative to DependencyHloOrdering.
class SequentialHloOrdering : public HloOrdering {
public:
- // TODO(dimvar): HloModuleSequence is not a good name because it sounds like
- // a sequence of modules, instead of a map of schedules for all computations
- // in a module. We should change it at some point.
- //
- // A sequence of instructions for each computation in the module.
- using HloModuleSequence =
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::vector<const HloInstruction*>>;
-
- SequentialHloOrdering(const HloModule* module,
- const HloModuleSequence& module_sequence);
+ SequentialHloOrdering(const HloSchedule& schedule);
+ SequentialHloOrdering(HloSchedule&& schedule);
~SequentialHloOrdering() override = default;
// Returns the sequential instruction order for the given computation.
@@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering {
string ToString() const override;
protected:
+ void Initialize();
+
bool ExecutesBeforeInSameComputation(const HloInstruction* a,
const HloInstruction* b) const override;
- const HloModuleSequence module_sequence_;
+ const HloSchedule schedule_;
// The position of every instruction in the HLO module in its respective
// computation sequence (a value of zero indicates the instruction is first in
@@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering {
tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
};
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2d9c..6b6005e7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -376,5 +378,104 @@ ENTRY root {
dataflow->GetValueDefinedAt(add_3)));
}
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of the module should interfere with values
+ // defined after the root instruction. That is:
+ //
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* entry =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param, root, dead});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of a computation should interfere with values
+ // defined after the root instruction of the computation. That is:
+ //
+ // subcomputation:
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // entry computation:
+ // %c = constant(42.0)
+ // ROOT %call = call({%c}), subcomputation
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+ HloInstruction* param = subbuilder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = subbuilder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = subbuilder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* subcomputation = module->AddEmbeddedComputation(
+ subbuilder.Build(/*root_instruction=*/root));
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+ HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(subcomputation, {param, root, dead});
+ schedule.set_sequence(entry, {c, call});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c9629926ea..0a0a6a323e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -962,8 +962,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
}
StatusOr<bool> HloRematerialization::RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ HloComputation* computation, HloSchedule* schedule,
int64 memory_limit_bytes) {
VLOG(1) << "Rematerializing computation " << computation->name()
<< " with limit " << HumanReadableNumBytes(memory_limit_bytes);
@@ -971,7 +970,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< HumanReadableNumBytes(computation_peak_memory_.at(computation));
CHECK(!ContainsKey(rematerialized_computations_, computation));
- InstructionList instruction_list(sequence->at(computation));
+ InstructionList instruction_list(
+ schedule->sequence(computation).instructions());
MemoryUsageTracker memory_tracker(computation, size_function_,
*points_to_analysis_, instruction_list);
bool changed = false;
@@ -1145,7 +1145,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
0, memory_limit_bytes - memory_tracker.memory_usage());
TF_ASSIGN_OR_RETURN(
bool subcomputation_changed,
- RematerializeComputation(called_computation, sequence,
+ RematerializeComputation(called_computation, schedule,
subcomputation_memory_limit_bytes));
changed |= subcomputation_changed;
}
@@ -1179,12 +1179,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
computation_peak_memory_.at(computation) = peak_memory;
// Update order to include rematerialized instructions.
- auto& dst = sequence->at(computation);
- dst.clear();
+ HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
+ sequence.clear();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
- dst.push_back(instruction);
+ sequence.push_back(instruction);
}
rematerialized_computations_.insert(computation);
@@ -1194,20 +1194,21 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(
- HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The sequence is constructed entirely by this method.
- TF_RET_CHECK(sequence->empty());
+StatusOr<bool> HloRematerialization::Run(HloModule* module,
+ HloSchedule* schedule,
+ int64 memory_limit_bytes,
+ RematerializationSizes* sizes,
+ CopyInsertion* copy_insertion) {
+ // The schedule is constructed entirely by this method.
+ TF_RET_CHECK(schedule->empty());
VLOG(1) << "HloRematerialization() with memory limit of "
<< HumanReadableNumBytes(memory_limit_bytes);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
- *module,
+ // Create initial schedule of HLO instructions.
+ TF_ASSIGN_OR_RETURN(*schedule,
+ ScheduleModule(*module,
[this](const BufferValue& buffer) {
return size_function_(buffer.shape());
},
@@ -1217,16 +1218,7 @@ StatusOr<bool> HloRematerialization::Run(
// ordering from the HLO schedule allows for more copies to be eliminated.
// TODO(b/80249101): Instead of a separate copy elision pass, use the
// ordering from the HLO schedule directly for copy insertion.
-
- // First create a copy of the schedule which contains HloInstruction unique
- // ids instead of HloInstruction*. This is necessary for updating the
- // schedule below.
- // TODO(b/113175018): Remove this when the HLO schedule is self-contained
- // and can update itself.
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(*sequence);
-
- SequentialHloOrdering ordering(module, *sequence);
+ SequentialHloOrdering ordering(*schedule);
TF_RETURN_IF_ERROR(
copy_insertion->RemoveUnnecessaryCopies(ordering, module));
@@ -1241,10 +1233,10 @@ StatusOr<bool> HloRematerialization::Run(
// The passes above can add and remove copies, update the schedule to
// account for these transformations. Newly added instructions will be
// placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence));
+ TF_RETURN_IF_ERROR(schedule->Update());
TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(module, *sequence), module));
+ SequentialHloOrdering(*schedule), module));
}
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
@@ -1271,12 +1263,13 @@ StatusOr<bool> HloRematerialization::Run(
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, sequence](const CallGraphNode& node) -> Status {
+ [this, schedule](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(node.computation(),
- sequence->at(node.computation())));
+ ComputePeakMemory(
+ node.computation(),
+ schedule->sequence(node.computation()).instructions()));
}
return Status::OK();
},
@@ -1295,7 +1288,7 @@ StatusOr<bool> HloRematerialization::Run(
// Subcomputations called by the entry computation will also be
// rematerialized.
TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), sequence,
+ module->entry_computation(), schedule,
adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
@@ -1305,30 +1298,7 @@ StatusOr<bool> HloRematerialization::Run(
// After DCE, the module sequence may include instructions which no longer
// exist.
- for (const auto* computation : module->MakeNonfusionComputations()) {
- if (sequence->at(computation).size() != computation->instruction_count()) {
- // A size mismatch between the computation instruction count and the size
- // of the ordering of instructions can only be caused by DCE. Rebuild the
- // order by removing the deleted instructions from the order.
- tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
- for (const auto& instruction : computation->instructions()) {
- instruction_set.insert(instruction);
- }
- // Move the old order into a temporary vector, then build new order
- // inplace.
- std::vector<const HloInstruction*>& order = sequence->at(computation);
- std::vector<const HloInstruction*> old_order;
- using std::swap;
- swap(order, old_order);
- std::copy_if(old_order.begin(), old_order.end(),
- std::back_inserter(order),
- [&instruction_set](const HloInstruction* instruction) {
- return ContainsKey(instruction_set, instruction);
- });
- TF_RET_CHECK(sequence->at(computation).size() ==
- computation->instruction_count());
- }
- }
+ TF_RETURN_IF_ERROR(schedule->Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1366,11 +1336,10 @@ StatusOr<bool> HloRematerialization::Run(
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
copy_insertion);
}
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ec004350a..fa0414b472 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -21,6 +21,7 @@
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -50,7 +51,7 @@ class HloRematerialization {
//
// hlo_module: HLO module to rematerialize instructions in.
//
- // sequence: Should point to an empty HloModuleSequence. Upon return
+ // schedule: Should point to an empty HloSchedule. Upon return
// contains the HLO instruction order which was used for
// rematerialization. This is the order in which HLO instructions should
// be emitted to minimize memory use.
@@ -75,8 +76,8 @@ class HloRematerialization {
static StatusOr<bool> RematerializeAndSchedule(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
+ HloSchedule* schedule, RematerializationSizes* sizes,
+ CopyInsertion* copy_insertion = nullptr);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -87,10 +88,9 @@ class HloRematerialization {
// Runs rematerialization on the given module. Returns whether the module was
// changed. memory_limit is the target maximum peak memory usage by the
- // module. sequence should be an empty HloModuleSequence. Upon return sequence
+ // module. schedule should be an empty HloSchedule. Upon return schedule
// contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
int64 memory_limit, RematerializationSizes* sizes,
CopyInsertion* copy_insertion);
@@ -98,10 +98,9 @@ class HloRematerialization {
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
// and inserted into 'order'.
- StatusOr<bool> RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
- int64 computation_memory_limit);
+ StatusOr<bool> RematerializeComputation(HloComputation* computation,
+ HloSchedule* schedule,
+ int64 memory_limit_bytes);
// Computes and returns the peak memory used by the given computation. The
// peak memory is the maximum total size of all live HLO instruction values at
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index ac8c97d380..83cb113bfb 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
- StatusOr<bool> RunHloRematerialization(
- int64 memory_limit_bytes, HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence) {
+ StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+ HloModule* module,
+ HloSchedule* schedule) {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence, /*sizes=*/nullptr);
+ schedule, /*sizes=*/nullptr);
}
// Various shapes used in the canned computations.
@@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/14 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,9 +187,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The rematerialized broadcast should be immediate before the concat in the
// sequence.
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
+ EXPECT_EQ(schedule.sequence(computation)
+ .instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
+ EXPECT_EQ(schedule.sequence(computation)
+ .instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -203,10 +205,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/20 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -242,10 +244,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/17 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -276,10 +278,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/15 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -316,10 +318,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/13 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -382,14 +384,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,13 +478,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -571,13 +573,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
new file mode 100644
index 0000000000..a65b33bf40
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -0,0 +1,291 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <queue>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+
+void HloSchedule::set_sequence(
+ const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence) {
+ set_sequence(computation, HloInstructionSequence(sequence));
+}
+
+void HloSchedule::set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence) {
+ CHECK(computation->parent() == module_);
+ sequences_[computation->unique_id()] = std::move(sequence);
+}
+
+HloInstructionSequence& HloSchedule::GetOrCreateSequence(
+ const HloComputation* computation) {
+ auto it = sequences_.find(computation->unique_id());
+ if (it == sequences_.end()) {
+ // No sequence found for computation. Create and return an empty one.
+ CHECK(computation->parent() == module_);
+ return sequences_[computation->unique_id()];
+ } else {
+ return it->second;
+ }
+}
+
+const HloInstructionSequence& HloSchedule::sequence(
+ const HloComputation* computation) const {
+ return sequences_.at(computation->unique_id());
+}
+
+Status HloSchedule::UpdateComputationSchedule(
+ const HloComputation* computation) {
+ // Map from unique ID to HloInstruction pointer for instructions in the
+ // computation.
+ tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
+ }
+
+ // Set of all HloInstructions in the schedule.
+ tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ InsertOrDie(&ids_in_schedule, id);
+ }
+
+ // Map from HloInstruction X to newly added instructions (instruction is in
+ // computation, but not in schedule) which use X. If an instruction is not in
+ // the map, then it has no users which are newly added instructions.
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::vector<const HloInstruction*>>
+ new_instruction_uses;
+
+ // For each newly added instruction, this is the count of the instruction's
+ // operands that have not yet been scheduled. When this value reaches zero,
+ // then the instruction may be placed in the schedule.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int>
+ unscheduled_operand_count;
+
+ // Create a worklist of newly added instructions which are ready to be added
+ // to the schedule. Initialize worklist with those that have zero operands.
+ std::queue<const HloInstruction*> worklist;
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ if (ids_in_schedule.count(instruction->unique_id()) == 0) {
+ // This is a newly added instruction which is not in the schedule.
+ if (instruction->operands().empty()) {
+ worklist.push(instruction);
+ } else {
+ for (const HloInstruction* operand : instruction->operands()) {
+ new_instruction_uses[operand].push_back(instruction);
+ }
+ unscheduled_operand_count[instruction] = instruction->operand_count();
+ }
+ }
+ }
+
+ // Update the schedule with the newly added instructions, and remove any
+ // instructions no longer in the graph.
+ HloInstructionSequence new_sequence;
+
+ // Lambda which schedules all instructions on the worklist.
+ auto schedule_worklist = [&]() {
+ while (!worklist.empty()) {
+ const HloInstruction* instruction = worklist.front();
+ worklist.pop();
+ new_sequence.push_back(instruction);
+ std::vector<const HloInstruction*>* new_users =
+ tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
+ if (new_users != nullptr) {
+ // This just-scheduled instruction has users which are newly added to
+ // the module. Update the number of unscheduled operands and push the
+ // newly added instruction to the worklist if it is ready to
+ // schedule.
+ for (const HloInstruction* new_user : *new_users) {
+ unscheduled_operand_count.at(new_user)--;
+ CHECK_GE(unscheduled_operand_count.at(new_user), 0);
+ if (unscheduled_operand_count.at(new_user) == 0) {
+ worklist.push(new_user);
+ }
+ }
+ }
+ }
+ };
+
+ schedule_worklist();
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ auto it = id_to_instruction.find(id);
+ if (it == id_to_instruction.end()) {
+ // This instruction in the schedule is no longer in the module. Do not add
+ // it to the new schedule.
+ continue;
+ }
+ worklist.push(it->second);
+ schedule_worklist();
+ }
+
+ set_sequence(computation, std::move(new_sequence));
+ return Status::OK();
+}
+
+Status HloSchedule::Update() {
+ // The schedule must contain a sequence for every non-fusion computation in
+ // the module, but can have sequences for computations which no longer exist
+ // (these are removed).
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name() << " not in HloSchedule.";
+ }
+ if (sequences_.size() > nonfusion_computations.size()) {
+ // Schedule contains some computations which have been removed from the
+ // HloModule. Remove them from the schedule as well.
+ tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ for (const HloComputation* computation : nonfusion_computations) {
+ nonfusion_computations_ids.insert(computation->unique_id());
+ }
+ for (auto it = sequences_.begin(); it != sequences_.end();) {
+ if (nonfusion_computations_ids.count(it->first) == 0) {
+ it = sequences_.erase(it);
+ } else {
+ it++;
+ }
+ }
+ }
+ CHECK_EQ(sequences_.size(), nonfusion_computations.size());
+
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
+ }
+
+ TF_RETURN_IF_ERROR(Verify());
+ return Status::OK();
+}
+
+Status HloSchedule::Verify() const {
+ VLOG(2) << "VerifySchedule()";
+ XLA_VLOG_LINES(3, module_->ToString());
+ XLA_VLOG_LINES(2, ToString());
+
+ // Verify schedule contains exactly the same set of non-fusion computations as
+ // module currently does.
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
+ << "Schedule has " << sequences_.size() << " sequences, but module has "
+ << nonfusion_computations.size() << " non-fusion computations";
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name()
+ << " missing from HLO schedule.";
+ }
+
+ // For each computation verify the set of instructions is the same and that
+ // each dependency and control edge is honored.
+ for (const HloComputation* computation : nonfusion_computations) {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ int pos = 0;
+ for (const HloInstruction* instruction :
+ sequence(computation).instructions()) {
+ TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
+ << "Instruction " << instruction->name()
+ << " appears more than once in the schedule";
+ pos++;
+ }
+
+ TF_RET_CHECK(instruction_position.size() ==
+ computation->instruction_count());
+ for (const HloInstruction* instruction : computation->instructions()) {
+ TF_RET_CHECK(instruction_position.count(instruction) == 1)
+ << "Instruction " << instruction->name() << " is not in schedule";
+ }
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ for (const HloInstruction* operand : instruction->operands()) {
+ TF_RET_CHECK(instruction_position.at(operand) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its operand " << operand->name();
+ }
+
+ for (const HloInstruction* pred : instruction->control_predecessors()) {
+ TF_RET_CHECK(instruction_position.at(pred) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its control predecessor "
+ << pred->name();
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+namespace {
+
+// Returns the computation in the given module with the given unique ID. Returns
+// nullptr if no such computation exists.
+const HloComputation* IdToComputation(const HloModule* module, int64 id) {
+ for (const HloComputation* computation : module->computations()) {
+ if (computation->unique_id() == id) {
+ return computation;
+ }
+ }
+ return nullptr;
+}
+
+} // namespace
+
+string HloSchedule::ToString() const {
+ std::vector<string> pieces;
+
+ pieces.push_back("HloSchedule");
+ for (const auto& id_sequence : sequences_) {
+ const HloComputation* computation =
+ IdToComputation(module_, id_sequence.first);
+ if (computation == nullptr) {
+ // The computation is not in the module and may have been deleted so it is
+ // not safe to dereference any HLO pointers. Just use the HLO unique ids
+ // stored in this object.
+ pieces.push_back(
+ absl::StrFormat("computation with id %d (no longer in HLO module):",
+ id_sequence.first));
+ for (int id : id_sequence.second.ids()) {
+ pieces.push_back(absl::StrCat(" ", id));
+ }
+ } else {
+ pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
+ for (const HloInstruction* instruction :
+ id_sequence.second.instructions()) {
+ pieces.push_back(absl::StrCat(" ", instruction->name()));
+ }
+ }
+ }
+ return absl::StrJoin(pieces, "\n");
+}
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
+ out << schedule.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
new file mode 100644
index 0000000000..21c6988638
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+
+// Class representing a sequence of HLO instructions such as the sequential
+// execution order of an HLO computation.
+class HloInstructionSequence {
+ public:
+ HloInstructionSequence() = default;
+ HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ for (const HloInstruction* instruction : instructions) {
+ push_back(instruction);
+ }
+ }
+
+ // Adds the instruction to the end of the sequence.
+ void push_back(const HloInstruction* instruction) {
+ instruction_sequence_.push_back(instruction);
+ id_sequence_.push_back(instruction->unique_id());
+ }
+
+ // Clears the sequence of all instructions.
+ void clear() {
+ instruction_sequence_.clear();
+ id_sequence_.clear();
+ }
+
+ int64 size() const { return instruction_sequence_.size(); }
+
+ // Returns the sequence of HLO instructions.
+ const std::vector<const HloInstruction*>& instructions() const {
+ return instruction_sequence_;
+ }
+
+ // Returns the unique IDs of the instructions in the sequence (in order).
+ const std::vector<int>& ids() const { return id_sequence_; }
+
+ private:
+ // The sequence as HloInstructions.
+ std::vector<const HloInstruction*> instruction_sequence_;
+
+ // The sequence of HLO instructions, represented by their unique IDs. The
+ // sequence is stored as both HloInstructions and unique IDs because the
+ // sequence may be referenced after transformations to the HLO graph and HLO
+ // pointers can be invalidated or recycled in this process (see
+ // HloSchedule::Update).
+ std::vector<int> id_sequence_;
+};
+
+// A class representing a sequential schedule of instructions for an HLO
+// module. A complete HLO schedule contains an instruction sequence for every
+// non-fusion computation in the HLO module.
+class HloSchedule {
+ public:
+ HloSchedule(const HloModule* module) : module_(module) {}
+
+ // Returns a reference to the sequence for the given computation.
+ const HloInstructionSequence& sequence(
+ const HloComputation* computation) const;
+
+ // Returns the sequence for the given computation. An empty sequence is
+ // created if none exists for the computation.
+ HloInstructionSequence& GetOrCreateSequence(
+ const HloComputation* computation);
+
+ // Sets the sequence for the given computation to the given sequence.
+ void set_sequence(const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence);
+ void set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence);
+
+ // Returns a map from HloComputation unique ID to instruction sequence. The
+ // map contains all sequences in the schedule.
+ const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
+ const {
+ return sequences_;
+ }
+
+ // Returns true if the schedule has a sequence for the given computation.
+ bool is_computation_scheduled(const HloComputation* computation) const {
+ return sequences_.count(computation->unique_id()) == 1;
+ }
+
+ // Updates the schedule such that it is (again) a valid schedule for the
+ // module. This is used to update a schedule after the HLO module has been
+ // transformed in some way. In general, the only transformations to the module
+ // for which a schedule can be updated are the addition or removal of
+ // instructions and removal of computations. Updating the schedule after new
+ // dependencies between existing instructions in the module is not supported
+ // and may result in an error status returned.
+ //
+ // Instructions in the module which also exist in the given schedule will
+ // remain in the same order in the updated schedule. Instructions which exist
+ // in the module but not in the given schedule will be placed as early as
+ // possible in the updated schedule.
+ Status Update();
+
+ // Verifies that the given schedule is valid for the given module.
+ // Specifically, the schedule contains exactly the instructions in the
+ // non-fusion computations in the module and every dependency in the module is
+ // satisfied in the schedule.
+ Status Verify() const;
+
+ string ToString() const;
+
+ bool empty() const { return sequences_.empty(); }
+
+ const HloModule* module() const { return module_; }
+
+ private:
+ // Updates the instruction sequence for the given computation.
+ Status UpdateComputationSchedule(const HloComputation* computation);
+
+ const HloModule* module_;
+
+ // A map from computation unique ID to instruction sequence. Unique IDs are
+ // used rather than HloComputation pointers because HLO pointers are not
+ // unique across HLO transformations because pointers may be recycled.
+ tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
new file mode 100644
index 0000000000..eb52582bb5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloScheduleTest : public HloTestBase {};
+
+TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
+ // Updating the schedule of an unchanged HLO module should not affect the
+ // schedule at all.
+ const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ const std::vector<const HloInstruction*>& entry_schedule =
+ schedule.sequence(module->entry_computation()).instructions();
+
+ EXPECT_EQ(entry_schedule.size(), 6);
+
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(entry_schedule,
+ schedule.sequence(module->entry_computation()).instructions());
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
+ // Add some additional instructions to a module and verify the schedule can be
+ // updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ HloComputation* entry = module->entry_computation();
+ const Shape shape = entry->root_instruction()->shape();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+ entry->set_root_instruction(sub);
+
+ auto in_schedule = [&](const HloInstruction* hlo) {
+ return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
+ };
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+ EXPECT_FALSE(in_schedule(constant));
+ EXPECT_FALSE(in_schedule(sub));
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 8);
+ EXPECT_TRUE(in_schedule(constant));
+ EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+ // Add and delete some instructions from a module and verify that the schedule
+ // can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Set the entry root to some expression containing just a parameter and a
+ // constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* new_root = entry->AddInstruction(
+ HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+ constant, entry->parameter_instruction(0)));
+ entry->set_root_instruction(new_root);
+
+ // DCE should remove everything but the parameters and the newly added code.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 4);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
+ // Completely replace a module with an entirely new set of instructions and
+ // verify that the schedule can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+ a = f32[] constant(42.0)
+ b = f32[] constant(123.0)
+ ROOT sum = f32[] add(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Replace the entry computation with the negation of a constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ entry->set_root_instruction(new_root);
+
+ // DCE the old instructions.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 3);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 2);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
+ // Create changes to more than one computation in an HLO module and verify
+ // that the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ const HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->operand(0);
+ HloComputation* body = xla_while->while_body();
+ HloComputation* cond = xla_while->while_condition();
+
+ // Negate the root of the cond.
+ cond->set_root_instruction(cond->AddInstruction(
+ HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kNot, cond->root_instruction())));
+
+ // Replace the body with a computation which just passes through its
+ // parameter.
+ body->set_root_instruction(body->parameter_instruction(0));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 7);
+ EXPECT_EQ(schedule.sequence(cond).size(), 4);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 1);
+ EXPECT_EQ(schedule.sequence(cond).size(), 5);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
+ // Remove computations from a module and verify the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ HloInstruction* init = xla_while->mutable_operand(0);
+
+ // Replace the while with its init value. The condition and body
+ // computations should then be dead.
+ TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));
+
+ // DCE the now-dead while condition and body computations.
+ HloDCE dce;
+ ASSERT_EQ(module->computation_count(), 3);
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+ ASSERT_EQ(module->computation_count(), 1);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 0fc3b268c0..9bfb0af96c 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -70,7 +70,7 @@ class ListScheduler {
public:
// Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation.
- static StatusOr<std::vector<const HloInstruction*>> Run(
+ static StatusOr<HloInstructionSequence> Run(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -229,8 +229,8 @@ class ListScheduler {
return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
}
- std::vector<const HloInstruction*> CreateSchedule() {
- std::vector<const HloInstruction*> schedule;
+ HloInstructionSequence CreateSchedule() {
+ HloInstructionSequence schedule;
// Populate the ready list with instructions which have no operands or
// control predecessors.
@@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
+StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -392,7 +392,7 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
} // namespace
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -443,7 +443,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
// Construct a total order based on DFS post-order, visiting operands in
// decreasing cumulative extra user order, and next by cumulative size, with a
// tiebreaker by name for determinism.
- std::vector<const HloInstruction*> sequence;
+ HloInstructionSequence sequence;
FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
sequence.push_back(hlo);
return Status::OK();
@@ -463,7 +463,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
return sequence;
} // namespace xla
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -473,18 +473,16 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
}
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation) {
- const auto& post_order = computation.MakeInstructionPostOrder();
- return std::vector<const HloInstruction*>{post_order.begin(),
- post_order.end()};
+ return HloInstructionSequence(computation.MakeInstructionPostOrder());
}
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -499,7 +497,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
// List wins for most of our benchmarks; postorder-based schedulers win for
// some RNNs.
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> list_sequence,
+ HloInstructionSequence list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 list_memory,
@@ -508,7 +506,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
- TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
@@ -518,7 +516,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> post_order_sequence,
+ HloInstructionSequence post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
@@ -545,32 +543,35 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
}
}
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm) {
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(&module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
- TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
HeapSimulator::MinimumMemoryForComputation(
- *computation, one_computation_sequence, *points_to_analysis,
+ *computation, computation_sequence, *points_to_analysis,
size_function, &memory_by_computation)
.ValueOrDie();
- sequence[computation] = std::move(one_computation_sequence);
+ schedule.set_sequence(computation, std::move(computation_sequence));
}
}
- VLOG(1) << "Module schedule:\n" << sequence;
- return sequence;
+ VLOG(1) << "Module schedule:\n" << schedule;
+
+ TF_RETURN_IF_ERROR(schedule.Verify());
+
+ return std::move(schedule);
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
@@ -581,187 +582,4 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
size_function, nullptr, empty_map);
}
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) {
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence;
- for (const auto& computation_sequence : sequence) {
- for (const HloInstruction* instruction : computation_sequence.second) {
- id_sequence[computation_sequence.first].push_back(
- instruction->unique_id());
- }
- }
- return id_sequence;
-}
-
-Status UpdateSchedule(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
- id_sequence,
- SequentialHloOrdering::HloModuleSequence* sequence) {
- // Map from unique ID to HloInstruction pointer for instructions in the
- // module.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
- // Set of all HloInstructions in the schedule.
- tensorflow::gtl::FlatSet<int> ids_in_schedule;
- std::vector<HloComputation*> nonfusion_computations =
- module.MakeNonfusionComputations();
- for (const HloComputation* computation : nonfusion_computations) {
- for (const HloInstruction* instruction : computation->instructions()) {
- TF_RET_CHECK(
- id_to_instruction.insert({instruction->unique_id(), instruction})
- .second);
- }
- for (int id : id_sequence.at(computation)) {
- ids_in_schedule.insert(id);
- }
- }
-
- // Map from HloInstruction X to newly added instructions (instruction is in
- // module, but not in schedule) which use X. If an instruction is not in the
- // map, then it has no users which are newly added instructions.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const HloInstruction*>>
- new_instruction_uses;
-
- // For each newly added instruction, this is the count of the instruction's
- // operands that have not yet been scheduled. When this value reaches zero,
- // then the instruction may be placed in the schedule.
- tensorflow::gtl::FlatMap<const HloInstruction*, int>
- unscheduled_operand_count;
- // For each computation, this is the set of newly added instructions which
- // have no operands. These must be handled specially and are added to the
- // beginning of the schedule.
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::vector<const HloInstruction*>>
- new_zero_operand_instructions;
- for (const HloComputation* computation : nonfusion_computations) {
- new_zero_operand_instructions[computation] = {};
- for (const HloInstruction* instruction : computation->instructions()) {
- if (ids_in_schedule.count(instruction->unique_id()) == 0) {
- // This is a newly added instruction which is not in the schedule.
- for (const HloInstruction* operand : instruction->operands()) {
- new_instruction_uses[operand].push_back(instruction);
- }
- if (instruction->operands().empty()) {
- new_zero_operand_instructions[computation].push_back(instruction);
- }
- unscheduled_operand_count[instruction] = instruction->operand_count();
- }
- }
- }
-
- // Update the schedule with the newly added instructions, and remove any
- // instructions no longer in the graph.
- for (const HloComputation* computation : nonfusion_computations) {
- std::vector<const HloInstruction*> old_computation_sequence =
- std::move(sequence->at(computation));
- sequence->at(computation).clear();
-
- // Create a worklist of newly added instructions which are ready to be added
- // to the schedule. Initialize worklist with those that have zero operands.
- std::queue<const HloInstruction*> worklist;
- for (const HloInstruction* instruction :
- new_zero_operand_instructions.at(computation)) {
- worklist.push(instruction);
- }
-
- // Lambda which schedules all instructions on the worklist.
- auto schedule_worklist = [&]() {
- while (!worklist.empty()) {
- const HloInstruction* instruction = worklist.front();
- worklist.pop();
- sequence->at(computation).push_back(instruction);
- std::vector<const HloInstruction*>* new_users =
- tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
- if (new_users != nullptr) {
- // This just-scheduled instruction has users which are newly added to
- // the module. Update the number of unscheduled operands and push the
- // newly added instruction to the worklist if it is ready to
- // schedule.
- for (const HloInstruction* new_user : *new_users) {
- unscheduled_operand_count.at(new_user)--;
- CHECK_GE(unscheduled_operand_count.at(new_user), 0);
- if (unscheduled_operand_count.at(new_user) == 0) {
- worklist.push(new_user);
- }
- }
- }
- }
- };
-
- schedule_worklist();
- for (int id : id_sequence.at(computation)) {
- auto it = id_to_instruction.find(id);
- if (it == id_to_instruction.end()) {
- // This instruction in the schedule is no longer in the module.
- continue;
- }
- const HloInstruction* instruction = it->second;
- worklist.push(instruction);
- schedule_worklist();
- }
- }
-
- TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence));
- return Status::OK();
-}
-
-Status VerifySchedule(
- const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& sequence) {
- VLOG(2) << "VerifySchedule()";
- XLA_VLOG_LINES(2, module.ToString());
- VLOG(2) << sequence;
-
- // Verify the set of computations in the sequence is exactly the set of
- // computations in the module.
- std::vector<HloComputation*> nonfusion_computations =
- module.MakeNonfusionComputations();
- TF_RET_CHECK(nonfusion_computations.size() == sequence.size());
- tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module(
- module.computations().begin(), module.computations().end());
- for (const auto& computation_sequence : sequence) {
- TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1);
- }
-
- // For each computation verify the set of instructions is the same and that
- // each dependency and control edge is honored.
- for (const HloComputation* computation : nonfusion_computations) {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
- int pos = 0;
- for (const HloInstruction* instruction : sequence.at(computation)) {
- TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
- << "Instruction " << instruction->name()
- << " appears more than once in the schedule";
- pos++;
- }
-
- TF_RET_CHECK(instruction_position.size() ==
- computation->instruction_count());
- for (const HloInstruction* instruction : computation->instructions()) {
- TF_RET_CHECK(instruction_position.count(instruction) == 1)
- << "Instruction " << instruction->name() << " is not in schedule";
- }
-
- for (const HloInstruction* instruction : computation->instructions()) {
- for (const HloInstruction* operand : instruction->operands()) {
- TF_RET_CHECK(instruction_position.at(operand) <
- instruction_position.at(instruction))
- << "Instruction " << instruction->name()
- << " is not scheduled after its operand " << operand->name();
- }
-
- for (const HloInstruction* pred : instruction->control_predecessors()) {
- TF_RET_CHECK(instruction_position.at(pred) <
- instruction_position.at(instruction))
- << "Instruction " << instruction->name()
- << " is not scheduled after its control predecessor "
- << pred->name();
- }
- }
- }
-
- return Status::OK();
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index d06b8d9a5c..54e32340ba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -32,14 +33,14 @@ namespace xla {
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
-typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -47,7 +48,7 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
// DFS-order scheduler
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -55,7 +56,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
memory_by_computation);
// Naive Post Order scheduler
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -65,63 +66,26 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation);
-// Returns an HloModuleSequence which seeks to minimize the memory required for
+// Returns an HloSchedule which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
-// Transforms the given schedule such that it is (again) a valid schedule for
-// the module. This is used to update a schedule after the HLO module has been
-// transformed in some way. In general, the only transformations to the module
-// for which a schedule can be updated is the addition or removal of
-// instructions to/from the module. Updating the schedule after new dependencies
-// between existing instructions in the module is not supported and may result
-// in an error status returned.
-//
-// Instructions in the module which also exist in the given schedule will remain
-// in the same order in the updated schedule. Instructions which exist in the
-// module but not in the given schedule will be placed as early as possible in
-// the updated schedule.
-//
-// 'id_sequence' is a mirror of the given schedule 'sequence' but with
-// HloInstruction ids rather than HloInstruction pointers. This should be
-// constructed using ComputeIdSchedule below after the schedule is constructed
-// but before the HLO module is transformed.
-Status UpdateSchedule(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
- id_sequence,
- SequentialHloOrdering::HloModuleSequence* sequence);
-
-// Constructs a copy of the given schedule but with HloInstruction unique ids
-// rather than HloInstruction pointers. This is necessary for updating a
-// schedule as HloInstruction points in the schedule may become invalid if
-// instructions are removed from the module. Used by UpdateSchedule above..
-// TODO(b/113175018): Remove this function when HLO schedule is its own class.
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence);
-
-// Verifies that the given schedule is valid for the given module. Specifically,
-// the schedule contains exactly the instructions in the module and every
-// dependency in the module is satisfied in the schedule.
-Status VerifySchedule(const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index d49d09d459..6afe51997e 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -67,19 +68,20 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ schedule.sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
- EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
- EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(param, sequence.front());
+ EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
}
@@ -108,28 +110,26 @@ ENTRY root {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ schedule.sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
std::unordered_map<string, const HloInstruction*> instructions_by_name;
- for (const HloInstruction* instruction :
- sequence.at(module->entry_computation())) {
+ for (const HloInstruction* instruction : sequence) {
instructions_by_name[instruction->name()] = instruction;
}
// The first instruction should be the parameter and the last the root.
- EXPECT_EQ(instructions_by_name.at("param"),
- sequence.at(module->entry_computation()).front());
- EXPECT_EQ(instructions_by_name.at("result"),
- sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
+ EXPECT_EQ(instructions_by_name.at("result"), sequence.back());
// Instructions "d" and "e" will both be schedulable at the same time, but
// instruction "d" allows us to free the buffer of "p1", so the list scheduler
// should prefer it.
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
instructions_by_name.at("e")));
}
@@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(entry_computation).size());
+ SequentialHloOrdering ordering(schedule);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
// better to schedule it first, instead of during the busy time.
@@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations. The output buffer is aliased,
// so we don't double count.
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
@@ -281,19 +281,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(
- buffer.shape(), TUPLE_SIZE);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), TUPLE_SIZE);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// tuple allocates the tuple buffer and doesn't free anything.
// abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
// abs_abs2 should be scheduled before tuple by List.
@@ -332,18 +331,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
auto fusion = computation->CreateFusionInstruction(
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), 2);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// fusion allocates memory for the tuple elements and doesn't free anything,
// so it's more expensive than exp.
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
@@ -391,12 +390,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
- EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ schedule.sequence(module->entry_computation()).size());
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
@@ -406,262 +405,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations. Cond is the largest one.
// The output buffer of the while is aliased.
EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
-TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) {
- // Updating the schedule of an unchanged HLO module should not affect the
- // schedule at all.
- const string module_str = R"(
-HloModule UpdateScheduleUnchanged
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
- std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second;
-
- EXPECT_EQ(entry_schedule.size(), 6);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(entry_schedule, sequence.begin()->second);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) {
- // Add some additional instructions to a module and verify the schedule can be
- // updated.
- const string module_str = R"(
-HloModule UpdateScheduleWithNewInstructions
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- HloComputation* entry = module->entry_computation();
- const Shape shape = entry->root_instruction()->shape();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
- HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
- entry->set_root_instruction(sub);
-
- auto in_schedule = [&](const HloInstruction* hlo) {
- return std::find(sequence.at(entry).begin(), sequence.at(entry).end(),
- hlo) != sequence.at(entry).end();
- };
-
- EXPECT_EQ(sequence.at(entry).size(), 6);
- EXPECT_FALSE(in_schedule(constant));
- EXPECT_FALSE(in_schedule(sub));
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 8);
- EXPECT_TRUE(in_schedule(constant));
- EXPECT_TRUE(in_schedule(sub));
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) {
- // Add and delete some instructions from a module and verify that the schedule
- // can be updated successfully.
- const string module_str = R"(
-HloModule UpdateScheduleWithAddedAndDeletedInstruction
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- // Set the entry root to some expression containing just a parameter and a
- // constant.
- HloComputation* entry = module->entry_computation();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
- HloInstruction* new_root = entry->AddInstruction(
- HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
- constant, entry->parameter_instruction(0)));
- entry->set_root_instruction(new_root);
-
- // DCE should remove everything but the parameters and the newly added code.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(entry).size(), 6);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 4);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) {
- // Completely replace a module with an entirely new set of instructions and
- // verify that the schedule can be updated successfully.
- const string module_str = R"(
-HloModule UpdateScheduleWithCompletelyReplacedModule
-
-ENTRY main {
- a = f32[] constant(42.0)
- b = f32[] constant(123.0)
- ROOT sum = f32[] add(a, b)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- // Replace the entry computation with the negation of a constant.
- HloComputation* entry = module->entry_computation();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kNegate, constant));
- entry->set_root_instruction(new_root);
-
- // DCE the old instructions.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(entry).size(), 3);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 2);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) {
- // Create changes to more than one computation in an HLO module and verify
- // that the schedule can be updated.
- const string module_str = R"(
-HloModule UpdateScheduleWithMultipleComputations
-
-%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
- %param.1 = (s32[], token[]) parameter(0)
- %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
- %constant.1 = s32[] constant(1)
- %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
- %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
- %after-all = token[] after-all(token[] %get-tuple-element.2)
- ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
-}
-
-%Cond (param: (s32[], token[])) -> pred[] {
- %param = (s32[], token[]) parameter(0)
- %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
- %constant = s32[] constant(42)
- ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
-}
-
-ENTRY %WhileLoop () -> s32[] {
- %zero = s32[] constant(0)
- %init_token = token[] after-all()
- %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
- %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
- ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(),
- /*pointer_size=*/sizeof(void*));
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- const HloInstruction* xla_while =
- module->entry_computation()->root_instruction()->operand(0);
- HloComputation* body = xla_while->while_body();
- HloComputation* cond = xla_while->while_condition();
-
- // Negate the root of the cond.
- cond->set_root_instruction(cond->AddInstruction(
- HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
- HloOpcode::kNot, cond->root_instruction())));
-
- // Replace the body with a computation which just passes through its
- // parameter.
- body->set_root_instruction(body->parameter_instruction(0));
-
- // DCE the dead code in the body.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(body).size(), 7);
- EXPECT_EQ(sequence.at(cond).size(), 4);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(body).size(), 1);
- EXPECT_EQ(sequence.at(cond).size(), 5);
-}
-
} // namespace
} // namespace xla