aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 10:34:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 10:41:37 -0700
commit7fa693209fe238478739b3982f652a7e35be91f3 (patch)
treeeb31635c366d9eceb144970ddb2b659441204ce1
parent08313b87960962efb98bcd684776c8305fa9909a (diff)
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encapsulation of code which deals with schedules and better enforcement of invariants. This CL also fixes a corner-case bug in dataflow analysis, where values of instructions which are live out of the computation erroneously did not interfere with the values of instructions scheduled after the root instruction. PiperOrigin-RevId: 211656888
-rw-r--r--tensorflow/compiler/xla/service/BUILD48
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc28
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc98
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc56
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h4
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc43
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h48
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc36
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc86
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h22
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc101
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h19
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.cc291
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.h151
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule_test.cc341
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc230
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h54
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc343
27 files changed, 1325 insertions, 905 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f6cfac6537..612302781c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -989,6 +989,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1036,6 +1037,7 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1049,6 +1051,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1062,6 +1065,7 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
+ ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1082,6 +1086,7 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1089,6 +1094,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -1102,6 +1108,7 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
+ ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1125,6 +1132,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
@@ -1170,6 +1178,43 @@ cc_library(
)
cc_library(
+ name = "hlo_schedule",
+ srcs = ["hlo_schedule.cc"],
+ hdrs = ["hlo_schedule.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_schedule_test",
+ srcs = ["hlo_schedule_test.cc"],
+ deps = [
+ ":heap_simulator",
+ ":hlo",
+ ":hlo_dce",
+ ":hlo_ordering",
+ ":hlo_parser",
+ ":hlo_schedule",
+ ":hlo_scheduling",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
+
+cc_library(
name = "hlo_scheduling",
srcs = ["hlo_scheduling.cc"],
hdrs = ["hlo_scheduling.h"],
@@ -1177,6 +1222,7 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1205,6 +1251,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2366,6 +2413,7 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
+ ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 8b8c6bfd26..0f0af57626 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -617,18 +617,24 @@ Status BufferAssignment::ComputeSummaryStats() {
}
// Only compute total fragmentation if all computations have schedules.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_);
+ bool schedule_complete = true;
for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- liveness_->hlo_ordering().SequentialOrder(*computation);
- if (sequence != nullptr) {
- module_sequence.emplace(computation, *sequence);
+ if (!computation->IsFusionComputation()) {
+ const std::vector<const HloInstruction*>* sequence =
+ liveness_->hlo_ordering().SequentialOrder(*computation);
+ if (sequence == nullptr) {
+ schedule_complete = false;
+ } else {
+ schedule.set_sequence(computation, *sequence);
+ }
}
}
- if (module_sequence.size() == module_->computation_count()) {
+ if (schedule_complete) {
+ TF_RETURN_IF_ERROR(schedule.Verify());
TF_ASSIGN_OR_RETURN(
const int64 min_size,
- HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
+ HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
@@ -1064,7 +1070,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// since buffers for kCall, kWhile, and kConditional sub-computations are
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(&assignment->module());
FlatSet<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
@@ -1072,7 +1078,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
- module_sequence[computation] = *instruction_sequence;
+ schedule.set_sequence(computation, *instruction_sequence);
all_buffers_to_assign.insert(buffers_to_assign.begin(),
buffers_to_assign.end());
}
@@ -1090,7 +1096,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const HeapSimulator::Result result,
HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<LazyBestFitHeap>(alignment)),
- assignment->module(), module_sequence,
+ assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
@@ -1121,7 +1127,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
HeapSimulator::Run(
absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<LazyBestFitHeap>(alignment)),
- *computation, *instruction_sequence,
+ *computation, HloInstructionSequence(*instruction_sequence),
assignment->points_to_analysis(), assignment->buffer_size_,
options));
AssignBuffersFromHeapSimulator(result, assignment,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 7398f105a0..03e155fc11 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase {
HloModule* module,
absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[module->entry_computation()] =
- std::vector<const HloInstruction*>(instruction_sequence.begin(),
- instruction_sequence.end());
+ HloSchedule schedule(module);
+ schedule.set_sequence(module->entry_computation(), instruction_sequence);
return BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module,
- module_sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -1785,11 +1783,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module, sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -2096,17 +2093,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// Create a sequential order among all the instructions in the entry
// computation, since the issue this test stresses depends on the order the
// nodes are traversed during BufferAssignment.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[module->entry_computation()] = {
- token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+ schedule.set_sequence(
+ module->entry_computation(),
+ {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
+ TF_ASSERT_OK(schedule.Verify());
+
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
- BufferAssigner::Run(
- module, absl::make_unique<SequentialHloOrdering>(module, sequence),
- backend().compiler()->BufferSizeBytesFunction(),
- [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// The result tuple elements must be assigned with different buffers.
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
@@ -2263,29 +2268,6 @@ ENTRY Main {
GetAllocation(*buffers, param0, {1, 1}));
}
-static bool IsPostOrderTraversal(
- const std::vector<const HloInstruction*>& sequence) {
- tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
- auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
- return seen_so_far.count(instruction) == 0;
- };
-
- for (auto instruction : sequence) {
- if (std::any_of(instruction->operands().begin(),
- instruction->operands().end(), has_not_been_seen_yet) ||
- std::any_of(instruction->control_predecessors().begin(),
- instruction->control_predecessors().end(),
- has_not_been_seen_yet)) {
- return false; // Not a post order.
- }
- if (!seen_so_far.insert(instruction).second) {
- return false; // Not a "traversal".
- }
- }
-
- return true;
-}
-
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
@@ -2340,27 +2322,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
RunCopyInsertion(module);
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
- // To trigger b/38494731, we want a specific Hlo sequence for the
+ // To trigger b/38494731, we want a specific Hlo schedule for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
- sequence[module->entry_computation()] = {
- input1, weights1, one, output1, while1->operand(0), while1,
- input0, weights0, zero, output0, while0->operand(0), while0,
- gte0, gte1, root_add};
+ schedule.set_sequence(module->entry_computation(),
+ {input1, weights1, one, output1, while1->operand(0),
+ while1, input0, weights0, zero, output0,
+ while0->operand(0), while0, gte0, gte1, root_add});
- // If this ASSERT_TRUE fails, we constructed a bogus sequence above
- // and this test itself is buggy.
- ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+ // If this ASSERT fails, we constructed a bogus sequence above and this test
+ // itself is buggy.
+ TF_ASSERT_OK(schedule.Verify());
auto assignment =
- BufferAssigner::Run(
- module, absl::make_unique<SequentialHloOrdering>(module, sequence),
- ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true)
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true)
.ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 26e26e316d..414bfe7999 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
auto module = CreateNewModule();
HloComputation* entry = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- module_sequence.emplace(computation, order);
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation, {param, negate, exp, add});
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build(add));
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, add, recv,
- recv_done, send, send_done};
- module_sequence.emplace(computation, order);
- auto liveness = BufferLiveness::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation,
+ {param, add, token, recv, recv_done, send, send_done});
+ TF_ASSERT_OK(schedule.Verify());
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
+ .ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Check the root instruction (add) buffer interferes with the recv buffer.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 796f36510e..e7b6075994 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -584,16 +584,14 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
- DFSMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
// Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(module.get(),
- absl::make_unique<SequentialHloOrdering>(
- module.get(), module_sequence),
+ absl::make_unique<SequentialHloOrdering>(schedule),
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
@@ -627,9 +625,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation, embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
string function_name_prefix = entry_computation->name().empty()
@@ -637,9 +636,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
: entry_computation->name();
TF_ASSIGN_OR_RETURN(
llvm::Function * entry_function,
- ir_emitter.EmitComputation(entry_computation, function_name_prefix,
- /*is_top_level_computation=*/true,
- &module_sequence.at(entry_computation)));
+ ir_emitter.EmitComputation(
+ entry_computation, function_name_prefix,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(entry_computation).instructions()));
string function_name = [&]() {
llvm::SmallVector<char, 40> function_name_vector;
@@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
- TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction()));
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module,
- absl::make_unique<SequentialHloOrdering>(module, module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation,
- embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
const string& entry_point_name = options.entry_point_name();
- TF_ASSIGN_OR_RETURN(
- llvm::Function * entry_function,
- ir_emitter.EmitComputation(computation, entry_point_name,
- /*is_top_level_computation=*/true,
- &module_sequence.at(computation)));
+ TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
+ ir_emitter.EmitComputation(
+ computation, entry_point_name,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(computation).instructions()));
CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e5cf15c686..df8c2a636b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -110,7 +110,7 @@ IrEmitter::IrEmitter(
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order) {
+ const std::vector<const HloInstruction*>* instruction_order) {
string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
<< "]; ordered? " << (instruction_order != nullptr);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 58a333b8fb..3df99464ba 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Function*> EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order);
+ const std::vector<const HloInstruction*>* instruction_order);
llvm::IRBuilder<>* b() { return &b_; }
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a68b7a1bef..13ccff35f8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -813,6 +813,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
+ "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index 743035a84e..ea9376e101 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
@@ -198,11 +199,12 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
// All kernels are launched on a single stream, so there's no loss of
// concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN(
- schedule->thunk_launch_order_,
- ScheduleOneComputation(
+ HloInstructionSequence sequence,
+ ScheduleComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
}));
+ schedule->thunk_launch_order_ = sequence.instructions();
} else {
// BFS tends to increase concurrency, but also increases memory usage.
BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
index 30a0e7cecd..07a7fc67aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
@@ -33,7 +33,9 @@ namespace gpu {
// launches, because thunks may be scheduled onto concurrent streams. This
// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
// minimize allocations), and also by ThunkSchedule to determine the thunk
-// launch order.
+// launch order. This class differs from xla::HloSchedule in that HloSchedule
+// represents a total order of all instructions in the module for backends which
+// execute HLO instructions strictly sequentially.
class GpuHloSchedule {
public:
// Constructs an GpuHloSchedule for the given module, based on the given
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 38c3982ebf..e0f3a7e0e2 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
+ if (schedule.empty()) {
return 0;
}
- const HloModule* module = module_sequence.begin()->first->parent();
+ const HloModule* module = schedule.module();
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -47,14 +47,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
+ schedule, *points_to_analysis, size_function));
return result.heap_size;
}
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
@@ -71,13 +70,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options) {
- HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
+ HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
const HloComputation* entry_computation = module.entry_computation();
- const std::vector<const HloInstruction*>& instruction_sequence =
- FindOrDie(module_sequence, entry_computation);
+ const HloInstructionSequence& instruction_sequence =
+ schedule.sequence(entry_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(
*entry_computation, instruction_sequence, points_to_analysis));
return heap.Finish();
@@ -86,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr, memory_by_computation);
+ /*schedule=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -102,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis) {
VLOG(3) << "Computation:\n" << computation.ToString();
// The goal here is to minimize memory usage, assuming the given sequential
@@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation(
// set of instructions that need to be visited contains all users of all
// aliases, that is, all users of all instructions that have the buffer
// contained in their points-to set.
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const PointsToSet& points_to =
points_to_analysis.GetPointsToSet(instruction);
const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
@@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation(
std::vector<const BufferValue*> dead_buffers_to_free;
std::vector<const BufferValue*> operand_buffers_to_free;
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const TuplePointsToAnalysis::BufferDefinitionVector&
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
@@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation(
// The order that the sub-computations are simulated does not affect
// correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
- if (module_sequence_ != nullptr) {
+ if (schedule_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kConditional ||
instruction->opcode() == HloOpcode::kWhile) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
- const std::vector<const HloInstruction*>& called_sequence =
- FindOrDie(*module_sequence_, called_computation);
+ const HloInstructionSequence& called_sequence =
+ schedule_->sequence(called_computation);
TF_RETURN_IF_ERROR(RunComputation(
*called_computation, called_sequence, points_to_analysis));
}
@@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const HloSchedule* schedule,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence),
+ schedule_(schedule),
memory_by_computation_(memory_by_computation) {
- debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
+ debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}
HeapSimulator::~HeapSimulator() {}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index af05bedee7..ffbf947d5a 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -88,23 +89,22 @@ class HeapSimulator {
// Returns the minimum memory required to compute an HLO module where all
// computations have been scheduled (represented by the given
- // module_sequence), assuming no fragmentation.
+ // schedule), assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function);
// Returns the minimum memory required to compute the given computation,
// assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation = nullptr);
// Run the heap simulation with the given algorithm, assuming the given
- // module_sequence, which must contain a topologically-consistent total
+ // schedule, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
// if instructions are not run in exactly this sequence.
//
@@ -112,12 +112,12 @@ class HeapSimulator {
// to running on a per-computation basis, since we can re-use buffer space for
// called sub-computations.
//
- static StatusOr<Result> Run(
- std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
+ const HloModule& module,
+ const HloSchedule& schedule,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options = Options());
// Same as above, but runs on a single computation. The 'instruction_sequence'
// must contain a topologically-consistent total ordering of all instructions
@@ -126,7 +126,7 @@ class HeapSimulator {
static StatusOr<Result> Run(
std::unique_ptr<HeapAlgorithm> algorithm,
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
const Options& options = Options(),
@@ -134,21 +134,19 @@ class HeapSimulator {
memory_by_computation = nullptr);
private:
- // If 'module_sequence' is non-null, it is used to find kCall and kWhile
+ // If 'schedule' is non-null, it is used to find kCall, kConditional, and kWhile
// sub-computations, and the heap simulation for those sub-computations will
// be run recursively. I.e. the simulation is run over the whole module.
- HeapSimulator(
- std::unique_ptr<HeapAlgorithm> algorithm,
- const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
- memory_by_computation = nullptr);
+ HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options, const HloSchedule* schedule = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
- Status RunComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
- const TuplePointsToAnalysis& points_to_analysis);
+ Status RunComputation(const HloComputation& computation,
+ const HloInstructionSequence& instruction_sequence,
+ const TuplePointsToAnalysis& points_to_analysis);
bool IgnoreBuffer(const BufferValue* buffer) const;
void Alloc(const BufferValue* buffer, const HloInstruction* instruction);
@@ -169,11 +167,11 @@ class HeapSimulator {
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
- // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // schedule_ is set by buffer assignment, and memory_by_computation_ is
// set by hlo scheduling. Then, in RunComputation, we check both in order to
// handle subcomputations. It would be good to unify the handling of
// subcomputations, but it's not clear how.
- const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const HloSchedule* schedule_;
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation_;
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 576c5ff7a4..1d98c45567 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
- .ValueOrDie());
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(cond_computation,
+ {cond_param, cond_iter, cond_data, cond_lt});
+ schedule.set_sequence(body_computation, {body_param});
+ schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(
+ 56,
+ HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
const char kAlloc[] = "Alloc";
@@ -149,10 +153,11 @@ class HeapSimulatorTracker {
auto zero_size = [](const BufferValue& buffer) { return 0; };
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(
- std::move(algorithm), *module_->entry_computation(),
- instruction_sequence, *points_to_analysis_, zero_size)
- .ConsumeValueOrDie();
+ result_ =
+ HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
+ HloInstructionSequence(instruction_sequence),
+ *points_to_analysis_, zero_size)
+ .ConsumeValueOrDie();
}
explicit HeapSimulatorTracker(const string& name) {
@@ -168,11 +173,12 @@ class HeapSimulatorTracker {
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
// Construct the module sequence grouped by computation.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_.get());
tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
- module_sequence[instruction->parent()].push_back(instruction);
+ schedule.GetOrCreateSequence(instruction->parent())
+ .push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i;
}
@@ -185,8 +191,8 @@ class HeapSimulatorTracker {
};
auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
absl::make_unique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(std::move(algorithm), *module_,
- module_sequence, *points_to_analysis_, size_fn)
+ result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
+ *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
}
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 54abe3345d..0cd0ab36fc 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
// For a sequential order, there is interference iff the negate is after
// the while.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[body] = {body_param, body_root};
- sequence[condition] = {cond_param, cond_root};
+ HloSchedule schedule(module_);
+ schedule.set_sequence(body, {body_param, body_root});
+ schedule.set_sequence(condition, {cond_param, cond_root});
{
- sequence[entry] = {init, xla_while, negate, entry_root};
- SequentialHloOrdering ordering(module_, sequence);
+ schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
{
- sequence[entry] = {init, negate, xla_while, entry_root};
- SequentialHloOrdering ordering(module_, sequence);
+ schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 62eea2b06c..0a86f83ed9 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
bool ssa_form = GetParam();
RunAnalysis(ssa_form);
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param, xla_while}});
- sequence.insert({condition, {cond_param, cond_constant}});
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, xla_while});
+ schedule.set_sequence(condition, {cond_param, cond_constant});
// Construct the order such that 'constant' and its use 'exp' are before
// body_param.
- sequence.insert({body, {constant, exp, body_param, add}});
+ schedule.set_sequence(
+ body, {constant, exp, body_param, add, dead_constant, dead_negate});
+ TF_ASSERT_OK(schedule.Verify());
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
// 'add' is live out of the body and will interfere with later instructions
// such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- sequence.emplace(entry, order);
-
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, negate, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 0581d5c404..2105f7a349 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore(
VLOG(4) << a << " not defined before " << b;
return false;
}
+
+ if (a.live_out_of_module()) {
+ VLOG(4) << a << " is live out of module and defined before " << b;
+ return false;
+ }
+
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
@@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore(
return false;
}
}
+
+ if (a.instruction()->parent() == b.instruction()->parent()) {
+ for (const HloPosition& position : a.positions()) {
+ if (position.instruction ==
+ a.instruction()->parent()->root_instruction()) {
+ VLOG(4) << a << " is live out of computation and defined before " << b
+ << " which is in same computation";
+ return false;
+ }
+ }
+ }
+
return true;
}
@@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const {
return ToStringHelper("DependencyHloOrdering");
}
-SequentialHloOrdering::SequentialHloOrdering(
- const HloModule* module, const HloModuleSequence& module_sequence)
- : HloOrdering(module), module_sequence_(module_sequence) {
+SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
+ : HloOrdering(schedule.module()), schedule_(schedule) {
+ Initialize();
+}
+
+SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
+ : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
+ Initialize();
+}
+
+void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position.
- for (auto computation_order : module_sequence_) {
- const std::vector<const HloInstruction*>& order = computation_order.second;
+ TF_DCHECK_OK(schedule_.Verify());
+ for (const auto& computation_sequence : schedule_.sequences()) {
+ const std::vector<const HloInstruction*>& order =
+ computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) {
- DCHECK_EQ(0, order_position_.count(order[i]));
- order_position_.emplace(order[i], i);
+ InsertOrDie(&order_position_, order[i], i);
}
}
}
@@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
const std::vector<const HloInstruction*>*
SequentialHloOrdering::SequentialOrder(
const HloComputation& computation) const {
- auto find_it = module_sequence_.find(&computation);
- return find_it == module_sequence_.end() ? nullptr : &find_it->second;
+ return schedule_.is_computation_scheduled(&computation)
+ ? &schedule_.sequence(&computation).instructions()
+ : nullptr;
}
string SequentialHloOrdering::ToString() const {
- std::vector<string> pieces;
- pieces.push_back("SequentialHloOrdering");
- for (auto* computation : module_->computations()) {
- pieces.push_back(
- absl::StrFormat("computation %s order:", computation->name()));
- // Gather all instructions in the module sequence for this computation and
- // sort them by their position.
- std::vector<const HloInstruction*> instructions;
- for (auto& instruction_position : order_position_) {
- const HloInstruction* instruction = instruction_position.first;
- if (instruction->parent() == computation) {
- instructions.push_back(instruction);
- }
- }
- std::sort(instructions.begin(), instructions.end(),
- [this](const HloInstruction* a, const HloInstruction* b) {
- return order_position_.at(a) < order_position_.at(b);
- });
- for (auto instruction : instructions) {
- pieces.push_back(absl::StrFormat(" %s", instruction->name()));
- }
- }
- return absl::StrJoin(pieces, "\n");
-}
-
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence) {
- for (auto computation_pair : module_sequence) {
- const HloComputation* computation = computation_pair.first;
- const std::vector<const HloInstruction*>& computation_sequence =
- computation_pair.second;
- out << "Computation " << computation->name() << ":\n";
- for (auto* instruction : computation_sequence) {
- out << " " << instruction->name() << "\n";
- }
- }
- return out;
+ return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index 985f3fa64d..b21071c4b2 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
// interference is reduced relative to DependencyHloOrdering.
class SequentialHloOrdering : public HloOrdering {
public:
- // TODO(dimvar): HloModuleSequence is not a good name because it sounds like
- // a sequence of modules, instead of a map of schedules for all computations
- // in a module. We should change it at some point.
- //
- // A sequence of instructions for each computation in the module.
- using HloModuleSequence =
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::vector<const HloInstruction*>>;
-
- SequentialHloOrdering(const HloModule* module,
- const HloModuleSequence& module_sequence);
+ SequentialHloOrdering(const HloSchedule& schedule);
+ SequentialHloOrdering(HloSchedule&& schedule);
~SequentialHloOrdering() override = default;
// Returns the sequential instruction order for the given computation.
@@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering {
string ToString() const override;
protected:
+ void Initialize();
+
bool ExecutesBeforeInSameComputation(const HloInstruction* a,
const HloInstruction* b) const override;
- const HloModuleSequence module_sequence_;
+ const HloSchedule schedule_;
// The position of every instruction in the HLO module in its respective
// computation sequence (a value of zero indicates the instruction is first in
@@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering {
tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
};
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2d9c..6b6005e7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -376,5 +378,104 @@ ENTRY root {
dataflow->GetValueDefinedAt(add_3)));
}
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of the module should interfere with values
+ // defined after the root instruction. That is:
+ //
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* entry =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param, root, dead});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of a computation should interfere with values
+ // defined after the root instruction of the computation. That is:
+ //
+ // subcomputation:
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // entry computation:
+ // %c = constant(42.0)
+ // ROOT %call = call({%c}), subcomputation
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+ HloInstruction* param = subbuilder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = subbuilder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = subbuilder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* subcomputation = module->AddEmbeddedComputation(
+ subbuilder.Build(/*root_instruction=*/root));
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+ HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(subcomputation, {param, root, dead});
+ schedule.set_sequence(entry, {c, call});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c9629926ea..0a0a6a323e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -962,8 +962,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
}
StatusOr<bool> HloRematerialization::RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ HloComputation* computation, HloSchedule* schedule,
int64 memory_limit_bytes) {
VLOG(1) << "Rematerializing computation " << computation->name()
<< " with limit " << HumanReadableNumBytes(memory_limit_bytes);
@@ -971,7 +970,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< HumanReadableNumBytes(computation_peak_memory_.at(computation));
CHECK(!ContainsKey(rematerialized_computations_, computation));
- InstructionList instruction_list(sequence->at(computation));
+ InstructionList instruction_list(
+ schedule->sequence(computation).instructions());
MemoryUsageTracker memory_tracker(computation, size_function_,
*points_to_analysis_, instruction_list);
bool changed = false;
@@ -1145,7 +1145,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
0, memory_limit_bytes - memory_tracker.memory_usage());
TF_ASSIGN_OR_RETURN(
bool subcomputation_changed,
- RematerializeComputation(called_computation, sequence,
+ RematerializeComputation(called_computation, schedule,
subcomputation_memory_limit_bytes));
changed |= subcomputation_changed;
}
@@ -1179,12 +1179,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
computation_peak_memory_.at(computation) = peak_memory;
// Update order to include rematerialized instructions.
- auto& dst = sequence->at(computation);
- dst.clear();
+ HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
+ sequence.clear();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
- dst.push_back(instruction);
+ sequence.push_back(instruction);
}
rematerialized_computations_.insert(computation);
@@ -1194,20 +1194,21 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(
- HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The sequence is constructed entirely by this method.
- TF_RET_CHECK(sequence->empty());
+StatusOr<bool> HloRematerialization::Run(HloModule* module,
+ HloSchedule* schedule,
+ int64 memory_limit_bytes,
+ RematerializationSizes* sizes,
+ CopyInsertion* copy_insertion) {
+ // The schedule is constructed entirely by this method.
+ TF_RET_CHECK(schedule->empty());
VLOG(1) << "HloRematerialization() with memory limit of "
<< HumanReadableNumBytes(memory_limit_bytes);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
- *module,
+ // Create initial schedule of HLO instructions.
+ TF_ASSIGN_OR_RETURN(*schedule,
+ ScheduleModule(*module,
[this](const BufferValue& buffer) {
return size_function_(buffer.shape());
},
@@ -1217,16 +1218,7 @@ StatusOr<bool> HloRematerialization::Run(
// ordering from the HLO schedule allows for more copies to be eliminated.
// TODO(b/80249101): Instead of a separate copy elision pass, use the
// ordering from the HLO schedule directly for copy insertion.
-
- // First create a copy of the schedule which contains HloInstruction unique
- // ids instead of HloInstruction*. This is necessary for updating the
- // schedule below.
- // TODO(b/113175018): Remove this when the HLO schedule is self-contained
- // and can update itself.
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(*sequence);
-
- SequentialHloOrdering ordering(module, *sequence);
+ SequentialHloOrdering ordering(*schedule);
TF_RETURN_IF_ERROR(
copy_insertion->RemoveUnnecessaryCopies(ordering, module));
@@ -1241,10 +1233,10 @@ StatusOr<bool> HloRematerialization::Run(
// The passes above can add and remove copies, update the schedule to
// account for these transformations. Newly added instructions will be
// placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence));
+ TF_RETURN_IF_ERROR(schedule->Update());
TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(module, *sequence), module));
+ SequentialHloOrdering(*schedule), module));
}
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
@@ -1271,12 +1263,13 @@ StatusOr<bool> HloRematerialization::Run(
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, sequence](const CallGraphNode& node) -> Status {
+ [this, schedule](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(node.computation(),
- sequence->at(node.computation())));
+ ComputePeakMemory(
+ node.computation(),
+ schedule->sequence(node.computation()).instructions()));
}
return Status::OK();
},
@@ -1295,7 +1288,7 @@ StatusOr<bool> HloRematerialization::Run(
// Subcomputations called by the entry computation will also be
// rematerialized.
TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), sequence,
+ module->entry_computation(), schedule,
adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
@@ -1305,30 +1298,7 @@ StatusOr<bool> HloRematerialization::Run(
// After DCE, the module sequence may include instructions which no longer
// exist.
- for (const auto* computation : module->MakeNonfusionComputations()) {
- if (sequence->at(computation).size() != computation->instruction_count()) {
- // A size mismatch between the computation instruction count and the size
- // of the ordering of instructions can only be caused by DCE. Rebuild the
- // order by removing the deleted instructions from the order.
- tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
- for (const auto& instruction : computation->instructions()) {
- instruction_set.insert(instruction);
- }
- // Move the old order into a temporary vector, then build new order
- // inplace.
- std::vector<const HloInstruction*>& order = sequence->at(computation);
- std::vector<const HloInstruction*> old_order;
- using std::swap;
- swap(order, old_order);
- std::copy_if(old_order.begin(), old_order.end(),
- std::back_inserter(order),
- [&instruction_set](const HloInstruction* instruction) {
- return ContainsKey(instruction_set, instruction);
- });
- TF_RET_CHECK(sequence->at(computation).size() ==
- computation->instruction_count());
- }
- }
+ TF_RETURN_IF_ERROR(schedule->Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1366,11 +1336,10 @@ StatusOr<bool> HloRematerialization::Run(
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
copy_insertion);
}
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ec004350a..fa0414b472 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -21,6 +21,7 @@
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -50,7 +51,7 @@ class HloRematerialization {
//
// hlo_module: HLO module to rematerialize instructions in.
//
- // sequence: Should point to an empty HloModuleSequence. Upon return
+ // schedule: Should point to an empty HloSchedule. Upon return
// contains the HLO instruction order which was used for
// rematerialization. This is the order in which HLO instructions should
// be emitted to minimize memory use.
@@ -75,8 +76,8 @@ class HloRematerialization {
static StatusOr<bool> RematerializeAndSchedule(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
+ HloSchedule* schedule, RematerializationSizes* sizes,
+ CopyInsertion* copy_insertion = nullptr);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -87,10 +88,9 @@ class HloRematerialization {
// Runs rematerialization on the given module. Returns whether the module was
// changed. memory_limit is the target maximum peak memory usage by the
- // module. sequence should be an empty HloModuleSequence. Upon return sequence
+ // module. schedule should be an empty HloSchedule. Upon return the schedule
// contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
int64 memory_limit, RematerializationSizes* sizes,
CopyInsertion* copy_insertion);
@@ -98,10 +98,9 @@ class HloRematerialization {
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
// and inserted into 'order'.
- StatusOr<bool> RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
- int64 computation_memory_limit);
+ StatusOr<bool> RematerializeComputation(HloComputation* computation,
+ HloSchedule* schedule,
+ int64 memory_limit_bytes);
// Computes and returns the peak memory used by the given computation. The
// peak memory is the maximum total size of all live HLO instruction values at
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index ac8c97d380..83cb113bfb 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
- StatusOr<bool> RunHloRematerialization(
- int64 memory_limit_bytes, HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence) {
+ StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+ HloModule* module,
+ HloSchedule* schedule) {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence, /*sizes=*/nullptr);
+ schedule, /*sizes=*/nullptr);
}
// Various shapes used in the canned computations.
@@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/14 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,9 +187,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
 // The rematerialized broadcast should be immediately before the concat in the
// sequence.
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
+ EXPECT_EQ(schedule.sequence(computation)
+ .instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
+ EXPECT_EQ(schedule.sequence(computation)
+ .instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -203,10 +205,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/20 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -242,10 +244,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/17 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -276,10 +278,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/15 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -316,10 +318,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/13 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -382,14 +384,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
 // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,13 +478,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
 // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -571,13 +573,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(module.get());
 // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ module.get(), &schedule));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
new file mode 100644
index 0000000000..a65b33bf40
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -0,0 +1,291 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <queue>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+
+void HloSchedule::set_sequence(
+ const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence) {
+ set_sequence(computation, HloInstructionSequence(sequence));
+}
+
+void HloSchedule::set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence) {
+ CHECK(computation->parent() == module_);
+ sequences_[computation->unique_id()] = std::move(sequence);
+}
+
+HloInstructionSequence& HloSchedule::GetOrCreateSequence(
+ const HloComputation* computation) {
+ auto it = sequences_.find(computation->unique_id());
+ if (it == sequences_.end()) {
+ // No sequence found for computation. Create and return an empty one.
+ CHECK(computation->parent() == module_);
+ return sequences_[computation->unique_id()];
+ } else {
+ return it->second;
+ }
+}
+
+const HloInstructionSequence& HloSchedule::sequence(
+ const HloComputation* computation) const {
+ return sequences_.at(computation->unique_id());
+}
+
+Status HloSchedule::UpdateComputationSchedule(
+ const HloComputation* computation) {
+ // Map from unique ID to HloInstruction pointer for instructions in the
+ // computation.
+ tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
+ }
+
+ // Set of all HloInstructions in the schedule.
+ tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ InsertOrDie(&ids_in_schedule, id);
+ }
+
+ // Map from HloInstruction X to newly added instructions (instruction is in
+ // computation, but not in schedule) which use X. If an instruction is not in
+ // the map, then it has no users which are newly added instructions.
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::vector<const HloInstruction*>>
+ new_instruction_uses;
+
+ // For each newly added instruction, this is the count of the instruction's
+ // operands that have not yet been scheduled. When this value reaches zero,
+ // then the instruction may be placed in the schedule.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int>
+ unscheduled_operand_count;
+
+ // Create a worklist of newly added instructions which are ready to be added
+ // to the schedule. Initialize worklist with those that have zero operands.
+ std::queue<const HloInstruction*> worklist;
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ if (ids_in_schedule.count(instruction->unique_id()) == 0) {
+ // This is a newly added instruction which is not in the schedule.
+ if (instruction->operands().empty()) {
+ worklist.push(instruction);
+ } else {
+ for (const HloInstruction* operand : instruction->operands()) {
+ new_instruction_uses[operand].push_back(instruction);
+ }
+ unscheduled_operand_count[instruction] = instruction->operand_count();
+ }
+ }
+ }
+
+ // Update the schedule with the newly added instructions, and remove any
+ // instructions no longer in the graph.
+ HloInstructionSequence new_sequence;
+
+ // Lambda which schedules all instructions on the worklist.
+ auto schedule_worklist = [&]() {
+ while (!worklist.empty()) {
+ const HloInstruction* instruction = worklist.front();
+ worklist.pop();
+ new_sequence.push_back(instruction);
+ std::vector<const HloInstruction*>* new_users =
+ tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
+ if (new_users != nullptr) {
+ // This just-scheduled instruction has users which are newly added to
+ // the module. Update the number of unscheduled operands and push the
+ // newly added instruction to the worklist if it is ready to
+ // schedule.
+ for (const HloInstruction* new_user : *new_users) {
+ unscheduled_operand_count.at(new_user)--;
+ CHECK_GE(unscheduled_operand_count.at(new_user), 0);
+ if (unscheduled_operand_count.at(new_user) == 0) {
+ worklist.push(new_user);
+ }
+ }
+ }
+ }
+ };
+
+ schedule_worklist();
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ auto it = id_to_instruction.find(id);
+ if (it == id_to_instruction.end()) {
+ // This instruction in the schedule is no longer in the module. Do not add
+ // it to the new schedule.
+ continue;
+ }
+ worklist.push(it->second);
+ schedule_worklist();
+ }
+
+ set_sequence(computation, std::move(new_sequence));
+ return Status::OK();
+}
+
+Status HloSchedule::Update() {
+ // The schedule must contain a sequence for every non-fusion computation in
+ // the module, but can have sequences for computations which no longer exist
+ // (these are removed).
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name() << " not in HloSchedule.";
+ }
+ if (sequences_.size() > nonfusion_computations.size()) {
+ // Schedule contains some computations which have been removed from the
+ // HloModule. Remove them from the schedule as well.
+ tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ for (const HloComputation* computation : nonfusion_computations) {
+ nonfusion_computations_ids.insert(computation->unique_id());
+ }
+ for (auto it = sequences_.begin(); it != sequences_.end();) {
+ if (nonfusion_computations_ids.count(it->first) == 0) {
+ it = sequences_.erase(it);
+ } else {
+ it++;
+ }
+ }
+ }
+ CHECK_EQ(sequences_.size(), nonfusion_computations.size());
+
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
+ }
+
+ TF_RETURN_IF_ERROR(Verify());
+ return Status::OK();
+}
+
+Status HloSchedule::Verify() const {
+ VLOG(2) << "VerifySchedule()";
+ XLA_VLOG_LINES(3, module_->ToString());
+ XLA_VLOG_LINES(2, ToString());
+
+ // Verify schedule contains exactly the same set of non-fusion computations as
+ // module currently does.
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
+ << "Schedule has " << sequences_.size() << " sequences, but module has "
+ << nonfusion_computations.size() << " non-fusion computations";
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name()
+ << " missing from HLO schedule.";
+ }
+
+ // For each computation verify the set of instructions is the same and that
+ // each dependency and control edge is honored.
+ for (const HloComputation* computation : nonfusion_computations) {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ int pos = 0;
+ for (const HloInstruction* instruction :
+ sequence(computation).instructions()) {
+ TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
+ << "Instruction " << instruction->name()
+ << " appears more than once in the schedule";
+ pos++;
+ }
+
+ TF_RET_CHECK(instruction_position.size() ==
+ computation->instruction_count());
+ for (const HloInstruction* instruction : computation->instructions()) {
+ TF_RET_CHECK(instruction_position.count(instruction) == 1)
+ << "Instruction " << instruction->name() << " is not in schedule";
+ }
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ for (const HloInstruction* operand : instruction->operands()) {
+ TF_RET_CHECK(instruction_position.at(operand) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its operand " << operand->name();
+ }
+
+ for (const HloInstruction* pred : instruction->control_predecessors()) {
+ TF_RET_CHECK(instruction_position.at(pred) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its control predecessor "
+ << pred->name();
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+namespace {
+
+// Returns the computation in the given module with the given unique ID. Returns
+// nullptr if no such computation exists.
+const HloComputation* IdToComputation(const HloModule* module, int64 id) {
+ for (const HloComputation* computation : module->computations()) {
+ if (computation->unique_id() == id) {
+ return computation;
+ }
+ }
+ return nullptr;
+}
+
+} // namespace
+
+string HloSchedule::ToString() const {
+ std::vector<string> pieces;
+
+ pieces.push_back("HloSchedule");
+ for (const auto& id_sequence : sequences_) {
+ const HloComputation* computation =
+ IdToComputation(module_, id_sequence.first);
+ if (computation == nullptr) {
+ // The computation is not in the module and may have been deleted so it is
+ // not safe to dereference any HLO pointers. Just use the HLO unique ids
+ // stored in this object.
+ pieces.push_back(
+ absl::StrFormat("computation with id %d (no longer in HLO module):",
+ id_sequence.first));
+ for (int id : id_sequence.second.ids()) {
+ pieces.push_back(absl::StrCat(" ", id));
+ }
+ } else {
+ pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
+ for (const HloInstruction* instruction :
+ id_sequence.second.instructions()) {
+ pieces.push_back(absl::StrCat(" ", instruction->name()));
+ }
+ }
+ }
+ return absl::StrJoin(pieces, "\n");
+}
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
+ out << schedule.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
new file mode 100644
index 0000000000..21c6988638
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+
+// Class representing a sequence of HLO instructions such as the sequential
+// execution order of an HLO computation.
+class HloInstructionSequence {
+ public:
+ HloInstructionSequence() = default;
+ HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ for (const HloInstruction* instruction : instructions) {
+ push_back(instruction);
+ }
+ }
+
+ // Adds the instruction to the end of the sequence.
+ void push_back(const HloInstruction* instruction) {
+ instruction_sequence_.push_back(instruction);
+ id_sequence_.push_back(instruction->unique_id());
+ }
+
+ // Clears the sequence of all instructions.
+ void clear() {
+ instruction_sequence_.clear();
+ id_sequence_.clear();
+ }
+
+ int64 size() const { return instruction_sequence_.size(); }
+
+ // Returns the sequence of HLO instructions.
+ const std::vector<const HloInstruction*>& instructions() const {
+ return instruction_sequence_;
+ }
+
+ // Returns the unique IDs of the instructions in the sequence (in order).
+ const std::vector<int>& ids() const { return id_sequence_; }
+
+ private:
+ // The sequence as HloInstructions.
+ std::vector<const HloInstruction*> instruction_sequence_;
+
+ // The sequence of HLO instructions, represented by their unique IDs. The
+ // sequence is stored as both HloInstructions and unique IDs because the
+ // sequence may be referenced after transformations to the HLO graph and HLO
+ // pointers can be invalidated or recycled in this process (see
+ // HloSchedule::Update).
+ std::vector<int> id_sequence_;
+};
+
+// A class representing a sequential schedule of instructions for an HLO
+// module. A complete HLO schedule contains an instruction sequence for every
+// non-fusion computation in the HLO module.
+class HloSchedule {
+ public:
+ HloSchedule(const HloModule* module) : module_(module) {}
+
+ // Returns a reference to the sequence for the given computation.
+ const HloInstructionSequence& sequence(
+ const HloComputation* computation) const;
+
+ // Returns the sequence for the given computation. An empty sequence is
+ // created if none exists for the computation.
+ HloInstructionSequence& GetOrCreateSequence(
+ const HloComputation* computation);
+
+ // Sets the sequence for the given computation to the given sequence.
+ void set_sequence(const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence);
+ void set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence);
+
+ // Returns a map from HloComputation unique ID to instruction sequence. The
+ // map contains all sequences in the schedule.
+ const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
+ const {
+ return sequences_;
+ }
+
+ // Returns true if the schedule has a sequence for the given computation.
+ bool is_computation_scheduled(const HloComputation* computation) const {
+ return sequences_.count(computation->unique_id()) == 1;
+ }
+
+ // Updates the schedule such that it is (again) a valid schedule for the
+ // module. This is used to update a schedule after the HLO module has been
+ // transformed in some way. In general, the only transformations to the module
+ // for which a schedule can be updated is the addition or removal of
+ // instructions and removal of computations. Updating the schedule after new
+ // dependencies between existing instructions in the module is not supported
+ // and may result in an error status returned.
+ //
+ // Instructions in the module which also exist in the given schedule will
+ // remain in the same order in the updated schedule. Instructions which exist
+ // in the module but not in the given schedule will be placed as early as
+ // possible in the updated schedule.
+ Status Update();
+
+ // Verifies that the given schedule is valid for the given module.
+ // Specifically, the schedule contains exactly the instructions in the
+ // non-fusion computations in the module and every dependency in the module is
+ // satisfied in the schedule.
+ Status Verify() const;
+
+ string ToString() const;
+
+ bool empty() const { return sequences_.empty(); }
+
+ const HloModule* module() const { return module_; }
+
+ private:
+ // Updates the instruction sequence for the given computation.
+ Status UpdateComputationSchedule(const HloComputation* computation);
+
+ const HloModule* module_;
+
+ // A map from computation unique ID to instruction sequence. Unique IDs are
+ // used rather than HloComputation pointers because HLO pointers are not
+ // unique across HLO transformations because pointers may be recycled.
+ tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
new file mode 100644
index 0000000000..eb52582bb5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloScheduleTest : public HloTestBase {};
+
+TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
+ // Updating the schedule of an unchanged HLO module should not affect the
+ // schedule at all.
+ const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ const std::vector<const HloInstruction*>& entry_schedule =
+ schedule.sequence(module->entry_computation()).instructions();
+
+ EXPECT_EQ(entry_schedule.size(), 6);
+
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(entry_schedule,
+ schedule.sequence(module->entry_computation()).instructions());
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
+ // Add some additional instructions to a module and verify the schedule can be
+ // updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ HloComputation* entry = module->entry_computation();
+ const Shape shape = entry->root_instruction()->shape();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+ entry->set_root_instruction(sub);
+
+ auto in_schedule = [&](const HloInstruction* hlo) {
+ return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
+ };
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+ EXPECT_FALSE(in_schedule(constant));
+ EXPECT_FALSE(in_schedule(sub));
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 8);
+ EXPECT_TRUE(in_schedule(constant));
+ EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+ // Add and delete some instructions from a module and verify that the schedule
+ // can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Set the entry root to some expression containing just a parameter and a
+ // constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* new_root = entry->AddInstruction(
+ HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+ constant, entry->parameter_instruction(0)));
+ entry->set_root_instruction(new_root);
+
+ // DCE should remove everything but the parameters and the newly added code.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 4);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
+ // Completely replace a module with an entirely new set of instructions and
+ // verify that the schedule can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+ a = f32[] constant(42.0)
+ b = f32[] constant(123.0)
+ ROOT sum = f32[] add(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Replace the entry computation with the negation of a constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ entry->set_root_instruction(new_root);
+
+ // DCE the old instructions.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 3);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 2);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
+ // Create changes to more than one computation in an HLO module and verify
+ // that the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ const HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->operand(0);
+ HloComputation* body = xla_while->while_body();
+ HloComputation* cond = xla_while->while_condition();
+
+ // Negate the root of the cond.
+ cond->set_root_instruction(cond->AddInstruction(
+ HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kNot, cond->root_instruction())));
+
+ // Replace the body with a computation which just passes through its
+ // parameter.
+ body->set_root_instruction(body->parameter_instruction(0));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 7);
+ EXPECT_EQ(schedule.sequence(cond).size(), 4);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 1);
+ EXPECT_EQ(schedule.sequence(cond).size(), 5);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
+ // Remove computations from a module and verify the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ HloInstruction* init = xla_while->mutable_operand(0);
+
+ // Replace the while with its init value. The conditional and body
+ // computations should then be dead.
+ TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));
+
+ // DCE the now-dead condition and body computations.
+ HloDCE dce;
+ ASSERT_EQ(module->computation_count(), 3);
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+ ASSERT_EQ(module->computation_count(), 1);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 0fc3b268c0..9bfb0af96c 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -70,7 +70,7 @@ class ListScheduler {
public:
// Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation.
- static StatusOr<std::vector<const HloInstruction*>> Run(
+ static StatusOr<HloInstructionSequence> Run(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -229,8 +229,8 @@ class ListScheduler {
return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
}
- std::vector<const HloInstruction*> CreateSchedule() {
- std::vector<const HloInstruction*> schedule;
+ HloInstructionSequence CreateSchedule() {
+ HloInstructionSequence schedule;
// Populate the ready list with instructions which have no operands or
// control predecessors.
@@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
+StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -392,7 +392,7 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
} // namespace
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -443,7 +443,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
// Construct a total order based on DFS post-order, visiting operands in
// decreasing cumulative extra user order, and next by cumulative size, with a
// tiebreaker by name for determinism.
- std::vector<const HloInstruction*> sequence;
+ HloInstructionSequence sequence;
FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
sequence.push_back(hlo);
return Status::OK();
@@ -463,7 +463,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
return sequence;
} // namespace xla
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -473,18 +473,16 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
}
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation) {
- const auto& post_order = computation.MakeInstructionPostOrder();
- return std::vector<const HloInstruction*>{post_order.begin(),
- post_order.end()};
+ return HloInstructionSequence(computation.MakeInstructionPostOrder());
}
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -499,7 +497,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
// List wins for most of our benchmarks; postorder-based schedulers win for
// some RNNs.
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> list_sequence,
+ HloInstructionSequence list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 list_memory,
@@ -508,7 +506,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
- TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
@@ -518,7 +516,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> post_order_sequence,
+ HloInstructionSequence post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
@@ -545,32 +543,35 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
}
}
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm) {
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(&module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
- TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
HeapSimulator::MinimumMemoryForComputation(
- *computation, one_computation_sequence, *points_to_analysis,
+ *computation, computation_sequence, *points_to_analysis,
size_function, &memory_by_computation)
.ValueOrDie();
- sequence[computation] = std::move(one_computation_sequence);
+ schedule.set_sequence(computation, std::move(computation_sequence));
}
}
- VLOG(1) << "Module schedule:\n" << sequence;
- return sequence;
+ VLOG(1) << "Module schedule:\n" << schedule;
+
+ TF_RETURN_IF_ERROR(schedule.Verify());
+
+ return std::move(schedule);
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
@@ -581,187 +582,4 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
size_function, nullptr, empty_map);
}
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) {
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence;
- for (const auto& computation_sequence : sequence) {
- for (const HloInstruction* instruction : computation_sequence.second) {
- id_sequence[computation_sequence.first].push_back(
- instruction->unique_id());
- }
- }
- return id_sequence;
-}
-
-Status UpdateSchedule(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
- id_sequence,
- SequentialHloOrdering::HloModuleSequence* sequence) {
- // Map from unique ID to HloInstruction pointer for instructions in the
- // module.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
- // Set of all HloInstructions in the schedule.
- tensorflow::gtl::FlatSet<int> ids_in_schedule;
- std::vector<HloComputation*> nonfusion_computations =
- module.MakeNonfusionComputations();
- for (const HloComputation* computation : nonfusion_computations) {
- for (const HloInstruction* instruction : computation->instructions()) {
- TF_RET_CHECK(
- id_to_instruction.insert({instruction->unique_id(), instruction})
- .second);
- }
- for (int id : id_sequence.at(computation)) {
- ids_in_schedule.insert(id);
- }
- }
-
- // Map from HloInstruction X to newly added instructions (instruction is in
- // module, but not in schedule) which use X. If an instruction is not in the
- // map, then it has no users which are newly added instructions.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const HloInstruction*>>
- new_instruction_uses;
-
- // For each newly added instruction, this is the count of the instruction's
- // operands that have not yet been scheduled. When this value reaches zero,
- // then the instruction may be placed in the schedule.
- tensorflow::gtl::FlatMap<const HloInstruction*, int>
- unscheduled_operand_count;
- // For each computation, this is the set of newly added instructions which
- // have no operands. These must be handled specially and are added to the
- // beginning of the schedule.
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::vector<const HloInstruction*>>
- new_zero_operand_instructions;
- for (const HloComputation* computation : nonfusion_computations) {
- new_zero_operand_instructions[computation] = {};
- for (const HloInstruction* instruction : computation->instructions()) {
- if (ids_in_schedule.count(instruction->unique_id()) == 0) {
- // This is a newly added instruction which is not in the schedule.
- for (const HloInstruction* operand : instruction->operands()) {
- new_instruction_uses[operand].push_back(instruction);
- }
- if (instruction->operands().empty()) {
- new_zero_operand_instructions[computation].push_back(instruction);
- }
- unscheduled_operand_count[instruction] = instruction->operand_count();
- }
- }
- }
-
- // Update the schedule with the newly added instructions, and remove any
- // instructions no longer in the graph.
- for (const HloComputation* computation : nonfusion_computations) {
- std::vector<const HloInstruction*> old_computation_sequence =
- std::move(sequence->at(computation));
- sequence->at(computation).clear();
-
- // Create a worklist of newly added instructions which are ready to be added
- // to the schedule. Initialize worklist with those that have zero operands.
- std::queue<const HloInstruction*> worklist;
- for (const HloInstruction* instruction :
- new_zero_operand_instructions.at(computation)) {
- worklist.push(instruction);
- }
-
- // Lambda which schedules all instructions on the worklist.
- auto schedule_worklist = [&]() {
- while (!worklist.empty()) {
- const HloInstruction* instruction = worklist.front();
- worklist.pop();
- sequence->at(computation).push_back(instruction);
- std::vector<const HloInstruction*>* new_users =
- tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
- if (new_users != nullptr) {
- // This just-scheduled instruction has users which are newly added to
- // the module. Update the number of unscheduled operands and push the
- // newly added instruction to the worklist if it is ready to
- // schedule.
- for (const HloInstruction* new_user : *new_users) {
- unscheduled_operand_count.at(new_user)--;
- CHECK_GE(unscheduled_operand_count.at(new_user), 0);
- if (unscheduled_operand_count.at(new_user) == 0) {
- worklist.push(new_user);
- }
- }
- }
- }
- };
-
- schedule_worklist();
- for (int id : id_sequence.at(computation)) {
- auto it = id_to_instruction.find(id);
- if (it == id_to_instruction.end()) {
- // This instruction in the schedule is no longer in the module.
- continue;
- }
- const HloInstruction* instruction = it->second;
- worklist.push(instruction);
- schedule_worklist();
- }
- }
-
- TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence));
- return Status::OK();
-}
-
-Status VerifySchedule(
- const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& sequence) {
- VLOG(2) << "VerifySchedule()";
- XLA_VLOG_LINES(2, module.ToString());
- VLOG(2) << sequence;
-
- // Verify the set of computations in the sequence is exactly the set of
- // computations in the module.
- std::vector<HloComputation*> nonfusion_computations =
- module.MakeNonfusionComputations();
- TF_RET_CHECK(nonfusion_computations.size() == sequence.size());
- tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module(
- module.computations().begin(), module.computations().end());
- for (const auto& computation_sequence : sequence) {
- TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1);
- }
-
- // For each computation verify the set of instructions is the same and that
- // each dependency and control edge is honored.
- for (const HloComputation* computation : nonfusion_computations) {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
- int pos = 0;
- for (const HloInstruction* instruction : sequence.at(computation)) {
- TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
- << "Instruction " << instruction->name()
- << " appears more than once in the schedule";
- pos++;
- }
-
- TF_RET_CHECK(instruction_position.size() ==
- computation->instruction_count());
- for (const HloInstruction* instruction : computation->instructions()) {
- TF_RET_CHECK(instruction_position.count(instruction) == 1)
- << "Instruction " << instruction->name() << " is not in schedule";
- }
-
- for (const HloInstruction* instruction : computation->instructions()) {
- for (const HloInstruction* operand : instruction->operands()) {
- TF_RET_CHECK(instruction_position.at(operand) <
- instruction_position.at(instruction))
- << "Instruction " << instruction->name()
- << " is not scheduled after its operand " << operand->name();
- }
-
- for (const HloInstruction* pred : instruction->control_predecessors()) {
- TF_RET_CHECK(instruction_position.at(pred) <
- instruction_position.at(instruction))
- << "Instruction " << instruction->name()
- << " is not scheduled after its control predecessor "
- << pred->name();
- }
- }
- }
-
- return Status::OK();
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index d06b8d9a5c..54e32340ba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -32,14 +33,14 @@ namespace xla {
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
-typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -47,7 +48,7 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
// DFS-order scheduler
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -55,7 +56,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
memory_by_computation);
// Naive Post Order scheduler
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -65,63 +66,26 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation);
-// Returns an HloModuleSequence which seeks to minimize the memory required for
+// Returns an HloSchedule which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
-// Transforms the given schedule such that it is (again) a valid schedule for
-// the module. This is used to update a schedule after the HLO module has been
-// transformed in some way. In general, the only transformations to the module
-// for which a schedule can be updated is the addition or removal of
-// instructions to/from the module. Updating the schedule after new dependencies
-// between existing instructions in the module is not supported and may result
-// in an error status returned.
-//
-// Instructions in the module which also exist in the given schedule will remain
-// in the same order in the updated schedule. Instructions which exist in the
-// module but not in the given schedule will be placed as early as possible in
-// the updated schedule.
-//
-// 'id_sequence' is a mirror of the given schedule 'sequence' but with
-// HloInstruction ids rather than HloInstruction pointers. This should be
-// constructed using ComputeIdSchedule below after the schedule is constructed
-// but before the HLO module is transformed.
-Status UpdateSchedule(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
- id_sequence,
- SequentialHloOrdering::HloModuleSequence* sequence);
-
-// Constructs a copy of the given schedule but with HloInstruction unique ids
-// rather than HloInstruction pointers. This is necessary for updating a
-// schedule as HloInstruction points in the schedule may become invalid if
-// instructions are removed from the module. Used by UpdateSchedule above..
-// TODO(b/113175018): Remove this function when HLO schedule is its own class.
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence);
-
-// Verifies that the given schedule is valid for the given module. Specifically,
-// the schedule contains exactly the instructions in the module and every
-// dependency in the module is satisfied in the schedule.
-Status VerifySchedule(const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index d49d09d459..6afe51997e 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -67,19 +68,20 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ schedule.sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
- EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
- EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(param, sequence.front());
+ EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
}
@@ -108,28 +110,26 @@ ENTRY root {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ schedule.sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
std::unordered_map<string, const HloInstruction*> instructions_by_name;
- for (const HloInstruction* instruction :
- sequence.at(module->entry_computation())) {
+ for (const HloInstruction* instruction : sequence) {
instructions_by_name[instruction->name()] = instruction;
}
// The first instruction should be the parameter and the last the root.
- EXPECT_EQ(instructions_by_name.at("param"),
- sequence.at(module->entry_computation()).front());
- EXPECT_EQ(instructions_by_name.at("result"),
- sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
+ EXPECT_EQ(instructions_by_name.at("result"), sequence.back());
// Instructions "d" and "e" will both be schedulable at the same time, but
// instruction "d" allows us to free the buffer of "p1", so the list scheduler
// should prefer it.
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
instructions_by_name.at("e")));
}
@@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(entry_computation).size());
+ SequentialHloOrdering ordering(schedule);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
// better to schedule it first, instead of during the busy time.
@@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations. The output buffer is aliased,
// so we don't double count.
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
@@ -281,19 +281,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(
- buffer.shape(), TUPLE_SIZE);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), TUPLE_SIZE);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// tuple allocates the tuple buffer and doesn't free anything.
// abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
// abs_abs2 should be scheduled before tuple by List.
@@ -332,18 +331,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
auto fusion = computation->CreateFusionInstruction(
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), 2);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// fusion allocates memory for the tuple elements and doesn't free anything,
// so it's more expensive than exp.
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
@@ -391,12 +390,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
- EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ schedule.sequence(module->entry_computation()).size());
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
@@ -406,262 +405,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations. Cond is the largest one.
// The output buffer of the while is aliased.
EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
-TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) {
- // Updating the schedule of an unchanged HLO module should not affect the
- // schedule at all.
- const string module_str = R"(
-HloModule UpdateScheduleUnchanged
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
- std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second;
-
- EXPECT_EQ(entry_schedule.size(), 6);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(entry_schedule, sequence.begin()->second);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) {
- // Add some additional instructions to a module and verify the schedule can be
- // updated.
- const string module_str = R"(
-HloModule UpdateScheduleWithNewInstructions
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- HloComputation* entry = module->entry_computation();
- const Shape shape = entry->root_instruction()->shape();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
- HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
- entry->set_root_instruction(sub);
-
- auto in_schedule = [&](const HloInstruction* hlo) {
- return std::find(sequence.at(entry).begin(), sequence.at(entry).end(),
- hlo) != sequence.at(entry).end();
- };
-
- EXPECT_EQ(sequence.at(entry).size(), 6);
- EXPECT_FALSE(in_schedule(constant));
- EXPECT_FALSE(in_schedule(sub));
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 8);
- EXPECT_TRUE(in_schedule(constant));
- EXPECT_TRUE(in_schedule(sub));
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) {
- // Add and delete some instructions from a module and verify that the schedule
- // can be updated successfully.
- const string module_str = R"(
-HloModule UpdateScheduleWithAddedAndDeletedInstruction
-
-ENTRY main {
- a = f32[] parameter(0)
- b = f32[] parameter(1)
- c = f32[] constant(42.0)
- sum = f32[] add(a, b)
- neg = f32[] negate(c)
- ROOT root = f32[] multiply(sum, neg)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- // Set the entry root to some expression containing just a parameter and a
- // constant.
- HloComputation* entry = module->entry_computation();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
- HloInstruction* new_root = entry->AddInstruction(
- HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
- constant, entry->parameter_instruction(0)));
- entry->set_root_instruction(new_root);
-
- // DCE should remove everything but the parameters and the newly added code.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(entry).size(), 6);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 4);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) {
- // Completely replace a module with an entirely new set of instructions and
- // verify that the schedule can be updated successfully.
- const string module_str = R"(
-HloModule UpdateScheduleWithCompletelyReplacedModule
-
-ENTRY main {
- a = f32[] constant(42.0)
- b = f32[] constant(123.0)
- ROOT sum = f32[] add(a, b)
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- // Replace the entry computation with the negation of a constant.
- HloComputation* entry = module->entry_computation();
- HloInstruction* constant = entry->AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kNegate, constant));
- entry->set_root_instruction(new_root);
-
- // DCE the old instructions.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(entry).size(), 3);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(entry).size(), 2);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) {
- // Create changes to more than one computation in an HLO module and verify
- // that the schedule can be updated.
- const string module_str = R"(
-HloModule UpdateScheduleWithMultipleComputations
-
-%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
- %param.1 = (s32[], token[]) parameter(0)
- %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
- %constant.1 = s32[] constant(1)
- %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
- %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
- %after-all = token[] after-all(token[] %get-tuple-element.2)
- ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
-}
-
-%Cond (param: (s32[], token[])) -> pred[] {
- %param = (s32[], token[]) parameter(0)
- %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
- %constant = s32[] constant(42)
- ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
-}
-
-ENTRY %WhileLoop () -> s32[] {
- %zero = s32[] constant(0)
- %init_token = token[] after-all()
- %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
- %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
- ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
-}
-)";
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(),
- /*pointer_size=*/sizeof(void*));
- }));
- tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
- id_sequence = ComputeIdSchedule(sequence);
-
- const HloInstruction* xla_while =
- module->entry_computation()->root_instruction()->operand(0);
- HloComputation* body = xla_while->while_body();
- HloComputation* cond = xla_while->while_condition();
-
- // Negate the root of the cond.
- cond->set_root_instruction(cond->AddInstruction(
- HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
- HloOpcode::kNot, cond->root_instruction())));
-
- // Replace the body with a computation which just passes through its
- // parameter.
- body->set_root_instruction(body->parameter_instruction(0));
-
- // DCE the dead code in the body.
- HloDCE dce;
- TF_ASSERT_OK(dce.Run(module.get()).status());
-
- EXPECT_EQ(sequence.at(body).size(), 7);
- EXPECT_EQ(sequence.at(cond).size(), 4);
-
- TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
- TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
- EXPECT_EQ(sequence.at(body).size(), 1);
- EXPECT_EQ(sequence.at(cond).size(), 5);
-}
-
} // namespace
} // namespace xla